mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5e2eccda7 | ||
|
|
58b230ff5a | ||
|
|
67bdef13e7 | ||
|
|
e67dd93ee8 | ||
|
|
3140a60816 | ||
|
|
41c2ee9f83 | ||
|
|
ca748ee12a | ||
|
|
243b12778f | ||
|
|
43c81910ae | ||
|
|
a11199aa67 | ||
|
|
5f82a71d5f | ||
|
|
1a305db162 | ||
|
|
48a653dc63 | ||
|
|
f6ddcbc6cb | ||
|
|
98f13a6e5d | ||
|
|
613978a611 | ||
|
|
2b0e8a5a9f | ||
|
|
08bb05141c | ||
|
|
3ccaa5e103 | ||
|
|
09e42041ce | ||
|
|
a50e95f210 | ||
|
|
92b395d82a | ||
|
|
86abfbd394 | ||
|
|
a7f4093424 | ||
|
|
e33b1e2105 | ||
|
|
fff101e037 | ||
|
|
f1ac05b2e0 | ||
|
|
f115607779 | ||
|
|
1aef8b7155 | ||
|
|
0da949ba42 | ||
|
|
6b031085bd | ||
|
|
11b846dd49 | ||
|
|
b9e29c96bd | ||
|
|
4ac0ba570a | ||
|
|
d61a2c6cd0 | ||
|
|
1c301b4b61 | ||
|
|
24d0c35ed3 | ||
|
|
8aae7751dc | ||
|
|
725da7e887 | ||
|
|
bd9e9ec614 | ||
|
|
88589764b5 | ||
|
|
c659f3b058 | ||
|
|
80581a8364 | ||
|
|
3c046eb291 | ||
|
|
3e25488b2d | ||
|
|
57b17dc8e1 | ||
|
|
a20188ae59 | ||
|
|
c410be890e | ||
|
|
37d9863552 | ||
|
|
2f42ff9b47 | ||
|
|
914efc53e5 | ||
|
|
17e78ca382 | ||
|
|
7ba05366ed | ||
|
|
ca74f980c1 | ||
|
|
68f5d2ad08 | ||
|
|
2b3d730ca9 | ||
|
|
f28628e34b | ||
|
|
b6a027fd2b | ||
|
|
fb74fcf4a4 | ||
|
|
28b26dde94 | ||
|
|
d677978c90 | ||
|
|
a347c274b7 | ||
|
|
f79d8f0449 | ||
|
|
1bc48c55d5 | ||
|
|
9d0a31c0f1 | ||
|
|
9b086e39c6 | ||
|
|
5867e4d613 | ||
|
|
85f0d8353a | ||
|
|
f871717f68 | ||
|
|
f08e52dc86 | ||
|
|
500b345b3b | ||
|
|
995dd1b5f3 | ||
|
|
336114f217 | ||
|
|
866563ad25 | ||
|
|
e79928a815 | ||
|
|
1771ed3bef | ||
|
|
550fa5a319 | ||
|
|
8528dffbf2 | ||
|
|
8fbf6a4b09 | ||
|
|
239148596c | ||
|
|
a880d73481 | ||
|
|
80bfd64ffa | ||
|
|
0076ad2a1a | ||
|
|
edb3d322f0 | ||
|
|
9381057079 | ||
|
|
f21a36ca37 | ||
|
|
ee5382a064 | ||
|
|
b80e5ea987 | ||
|
|
3d4fcfacb6 | ||
|
|
32eac6d52e | ||
|
|
9762f4cde7 | ||
|
|
76901ba22f |
1
.agents/skills
Symbolic link
1
.agents/skills
Symbolic link
@@ -0,0 +1 @@
|
||||
../.claude/skills
|
||||
10
.claude/settings.json
Normal file
10
.claude/settings.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allowedTools": [
|
||||
"Read", "Grep", "Glob",
|
||||
"Bash(ls:*)", "Bash(cat:*)", "Bash(grep:*)", "Bash(find:*)",
|
||||
"Bash(git status:*)", "Bash(git diff:*)", "Bash(git log:*)", "Bash(git worktree:*)",
|
||||
"Bash(tmux:*)", "Bash(sleep:*)", "Bash(branchlet:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
106
.claude/skills/open-pr/SKILL.md
Normal file
106
.claude/skills/open-pr/SKILL.md
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
545
.claude/skills/orchestrate/SKILL.md
Normal file
545
.claude/skills/orchestrate/SKILL.md
Normal file
@@ -0,0 +1,545 @@
|
||||
---
|
||||
name: orchestrate
|
||||
description: "Meta-agent supervisor that manages a fleet of Claude Code agents running in tmux windows. Auto-discovers spare worktrees, spawns agents, monitors state, kicks idle agents, approves safe confirmations, and recycles worktrees when done. TRIGGER when user asks to supervise agents, run parallel tasks, manage worktrees, check agent status, or orchestrate parallel work."
|
||||
user-invocable: true
|
||||
argument-hint: "any free text — e.g. 'start 3 agents on X Y Z', 'show status', 'add task: implement feature A', 'stop', 'how many are free?'"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "6.0.0"
|
||||
---
|
||||
|
||||
# Orchestrate — Agent Fleet Supervisor
|
||||
|
||||
One tmux session, N windows — each window is one agent working in its own worktree. Speak naturally; Claude maps your intent to the right scripts.
|
||||
|
||||
## Scripts
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=$(git rev-parse --show-toplevel)/.claude/skills/orchestrate/scripts
|
||||
STATE_FILE=~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
| Script | Purpose |
|
||||
|---|---|
|
||||
| `find-spare.sh [REPO_ROOT]` | List free worktrees — one `PATH BRANCH` per line |
|
||||
| `spawn-agent.sh SESSION PATH SPARE NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]` | Create window + checkout branch + launch claude + send task. **Stdout: `SESSION:WIN` only** |
|
||||
| `recycle-agent.sh WINDOW PATH SPARE_BRANCH` | Kill window + restore spare branch |
|
||||
| `run-loop.sh` | **Mechanical babysitter** — idle restart + dialog approval + recycle on ORCHESTRATOR:DONE + supervisor health check + all-done notification |
|
||||
| `verify-complete.sh WINDOW` | Verify PR is done: checkpoints ✓ + 0 unresolved threads + CI green + no fresh CHANGES_REQUESTED. Repo auto-derived from state file `.repo` or git remote. |
|
||||
| `notify.sh MESSAGE` | Send notification via Discord webhook (env `DISCORD_WEBHOOK_URL` or state `.discord_webhook`), macOS notification center, and stdout |
|
||||
| `capacity.sh [REPO_ROOT]` | Print available + in-use worktrees |
|
||||
| `status.sh` | Print fleet status + live pane commands |
|
||||
| `poll-cycle.sh` | One monitoring cycle — classifies panes, tracks checkpoints, returns JSON action array |
|
||||
| `classify-pane.sh WINDOW` | Classify one pane state |
|
||||
|
||||
## Supervision model
|
||||
|
||||
```
|
||||
Orchestrating Claude (this Claude session — IS the supervisor)
|
||||
└── Reads pane output, checks CI, intervenes with targeted guidance
|
||||
run-loop.sh (separate tmux window, every 30s)
|
||||
└── Mechanical only: idle restart, dialog approval, recycle on ORCHESTRATOR:DONE
|
||||
```
|
||||
|
||||
**You (the orchestrating Claude)** are the supervisor. After spawning agents, stay in this conversation and actively monitor: poll each agent's pane every 2-3 minutes, check CI, nudge stalled agents, and verify completions. Do not spawn a separate supervisor Claude window — it loses context, is hard to observe, and compounds context compression problems.
|
||||
|
||||
**run-loop.sh** is the mechanical layer — zero tokens, handles things that need no judgment: restart crashed agents, press Enter on dialogs, recycle completed worktrees (only after `verify-complete.sh` passes).
|
||||
|
||||
## Checkpoint protocol
|
||||
|
||||
Agents output checkpoints as they complete each required step:
|
||||
|
||||
```
|
||||
CHECKPOINT:<step-name>
|
||||
```
|
||||
|
||||
Required steps are passed as args to `spawn-agent.sh` (e.g. `pr-address pr-test`). `run-loop.sh` will not recycle a window until all required checkpoints are found in the pane output. If `verify-complete.sh` fails, the agent is re-briefed automatically.
|
||||
|
||||
## Worktree lifecycle
|
||||
|
||||
```text
|
||||
spare/N branch → spawn-agent.sh (--session-id UUID) → window + feat/branch + claude running
|
||||
↓
|
||||
CHECKPOINT:<step> (as steps complete)
|
||||
↓
|
||||
ORCHESTRATOR:DONE
|
||||
↓
|
||||
verify-complete.sh: checkpoints ✓ + 0 threads + CI green + no fresh CHANGES_REQUESTED
|
||||
↓
|
||||
state → "done", notify, window KEPT OPEN
|
||||
↓
|
||||
user/orchestrator explicitly requests recycle
|
||||
↓
|
||||
recycle-agent.sh → spare/N (free again)
|
||||
```
|
||||
|
||||
**Windows are never auto-killed.** The worktree stays on its branch, the session stays alive. The agent is done working but the window, git state, and Claude session are all preserved until you choose to recycle.
|
||||
|
||||
**To resume a done or crashed session:**
|
||||
```bash
|
||||
# Resume by stored session ID (preferred — exact session, full context)
|
||||
claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
|
||||
# Or resume most recent session in that worktree directory
|
||||
cd /path/to/worktree && claude --continue --permission-mode bypassPermissions
|
||||
```
|
||||
|
||||
**To manually recycle when ready:**
|
||||
```bash
|
||||
bash ~/.claude/orchestrator/scripts/recycle-agent.sh SESSION:WIN WORKTREE_PATH spare/N
|
||||
# Then update state:
|
||||
jq --arg w "SESSION:WIN" '.agents |= map(if .window == $w then .state = "recycled" else . end)' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## State file (`~/.claude/orchestrator-state.json`)
|
||||
|
||||
Never committed to git. You maintain this file directly using `jq` + atomic writes (`.tmp` → `mv`).
|
||||
|
||||
```json
|
||||
{
|
||||
"active": true,
|
||||
"tmux_session": "autogpt1",
|
||||
"idle_threshold_seconds": 300,
|
||||
"loop_window": "autogpt1:5",
|
||||
"repo": "Significant-Gravitas/AutoGPT",
|
||||
"discord_webhook": "https://discord.com/api/webhooks/...",
|
||||
"last_poll_at": 0,
|
||||
"agents": [
|
||||
{
|
||||
"window": "autogpt1:3",
|
||||
"worktree": "AutoGPT6",
|
||||
"worktree_path": "/path/to/AutoGPT6",
|
||||
"spare_branch": "spare/6",
|
||||
"branch": "feat/my-feature",
|
||||
"objective": "Implement X and open a PR",
|
||||
"pr_number": "12345",
|
||||
"session_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"steps": ["pr-address", "pr-test"],
|
||||
"checkpoints": ["pr-address"],
|
||||
"state": "running",
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": 0,
|
||||
"spawned_at": 0,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Top-level optional fields:
|
||||
- `repo` — GitHub `owner/repo` for CI/thread checks. Auto-derived from git remote if omitted.
|
||||
- `discord_webhook` — Discord webhook URL for completion notifications. Also reads `DISCORD_WEBHOOK_URL` env var.
|
||||
|
||||
Per-agent fields:
|
||||
- `session_id` — UUID passed to `claude --session-id` at spawn; use with `claude --resume UUID` to restore exact session context after a crash or window close.
|
||||
- `last_rebriefed_at` — Unix timestamp of last re-brief; enforces 5-min cooldown to prevent spam.
|
||||
|
||||
Agent states: `running` | `idle` | `stuck` | `waiting_approval` | `complete` | `done` | `escalated`
|
||||
|
||||
`done` means verified complete — window is still open, session still alive, worktree still on task branch. Not recycled yet.
|
||||
|
||||
## Serial /pr-test rule
|
||||
|
||||
`/pr-test` and `/pr-test --fix` run local Docker + integration tests that use shared ports, a shared database, and shared build caches. **Running two `/pr-test` jobs simultaneously will cause port conflicts and database corruption.**
|
||||
|
||||
**Rule: only one `/pr-test` runs at a time. The orchestrator serializes them.**
|
||||
|
||||
You (the orchestrating Claude) own the test queue:
|
||||
1. Agents do `pr-review` and `pr-address` in parallel — that's safe (they only push code and reply to GitHub).
|
||||
2. When a PR needs local testing, add it to your mental queue — don't give agents a `pr-test` step.
|
||||
3. Run `/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER --fix` yourself, sequentially.
|
||||
4. Feed results back to the relevant agent via `tmux send-keys`:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "Local tests for PR #N: <paste failure output or 'all passed'>. Fix any failures and push, then output ORCHESTRATOR:DONE."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
5. Wait for CI to confirm green before marking the agent done.
|
||||
|
||||
If multiple PRs need testing at the same time, pick the one furthest along (fewest pending CI checks) and test it first. Only start the next test after the previous one completes.
|
||||
|
||||
## Session restore (tested and confirmed)
|
||||
|
||||
Agent sessions are saved to disk. To restore a closed or crashed session:
|
||||
|
||||
```bash
|
||||
# If session_id is in state (preferred):
|
||||
NEW_WIN=$(tmux new-window -t SESSION -n WORKTREE_NAME -P -F '#{window_index}')
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --resume SESSION_ID --permission-mode bypassPermissions" Enter
|
||||
|
||||
# If no session_id (use --continue for most recent session in that directory):
|
||||
tmux send-keys -t "SESSION:${NEW_WIN}" "cd /path/to/worktree && claude --continue --permission-mode bypassPermissions" Enter
|
||||
```
|
||||
|
||||
`--continue` restores the full conversation history including all tool calls, file edits, and context. The agent resumes exactly where it left off. After restoring, update the window address in the state file:
|
||||
|
||||
```bash
|
||||
jq --arg old "SESSION:OLD_WIN" --arg new "SESSION:NEW_WIN" \
|
||||
'(.agents[] | select(.window == $old)).window = $new' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## Intent → action mapping
|
||||
|
||||
Match the user's message to one of these intents:
|
||||
|
||||
| The user says something like… | What to do |
|
||||
|---|---|
|
||||
| "status", "what's running", "show agents" | Run `status.sh` + `capacity.sh`, show output |
|
||||
| "how many free", "capacity", "available worktrees" | Run `capacity.sh`, show output |
|
||||
| "start N agents on X, Y, Z" or "run these tasks: …" | See **Spawning agents** below |
|
||||
| "add task: …", "add one more agent for …" | See **Adding an agent** below |
|
||||
| "stop", "shut down", "pause the fleet" | See **Stopping** below |
|
||||
| "poll", "check now", "run a cycle" | Run `poll-cycle.sh`, process actions |
|
||||
| "recycle window X", "free up autogpt3" | Run `recycle-agent.sh` directly |
|
||||
|
||||
When the intent is ambiguous, show capacity first and ask what tasks to run.
|
||||
|
||||
## Spawning agents
|
||||
|
||||
### 1. Resolve tmux session
|
||||
|
||||
```bash
|
||||
tmux list-sessions -F "#{session_name}: #{session_windows} windows" 2>/dev/null
|
||||
```
|
||||
|
||||
Use an existing session. **Never create a tmux session from within Claude** — it becomes a child of Claude's process and dies when the session ends. If no session exists, tell the user to run `tmux new-session -d -s autogpt1` in their terminal first, then re-invoke `/orchestrate`.
|
||||
|
||||
### 2. Show available capacity
|
||||
|
||||
```bash
|
||||
bash $SKILLS_DIR/capacity.sh $(git rev-parse --show-toplevel)
|
||||
```
|
||||
|
||||
### 3. Collect tasks from the user
|
||||
|
||||
For each task, gather:
|
||||
- **objective** — what to do (e.g. "implement feature X and open a PR")
|
||||
- **branch name** — e.g. `feat/my-feature` (derive from objective if not given)
|
||||
- **pr_number** — GitHub PR number if working on an existing PR (for verification)
|
||||
- **steps** — required checkpoint names in order (e.g. `pr-address pr-test`) — derive from objective
|
||||
|
||||
Ask for `idle_threshold_seconds` only if the user mentions it (default: 300).
|
||||
|
||||
Never ask the user to specify a worktree — auto-assign from `find-spare.sh`.
|
||||
|
||||
### 4. Spawn one agent per task
|
||||
|
||||
```bash
|
||||
# Get ordered list of spare worktrees
|
||||
SPARE_LIST=$(bash $SKILLS_DIR/find-spare.sh $(git rev-parse --show-toplevel))
|
||||
|
||||
# For each task, take the next spare line:
|
||||
WORKTREE_PATH=$(echo "$SPARE_LINE" | awk '{print $1}')
|
||||
SPARE_BRANCH=$(echo "$SPARE_LINE" | awk '{print $2}')
|
||||
|
||||
# With PR number and required steps:
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE" "$PR_NUMBER" "pr-address" "pr-test")
|
||||
|
||||
# Without PR (new work):
|
||||
WINDOW=$(bash $SKILLS_DIR/spawn-agent.sh "$SESSION" "$WORKTREE_PATH" "$SPARE_BRANCH" "$NEW_BRANCH" "$OBJECTIVE")
|
||||
```
|
||||
|
||||
Build an agent record and append it to the state file. If the state file doesn't exist yet, initialize it:
|
||||
|
||||
```bash
|
||||
# Derive repo from git remote (used by verify-complete.sh + supervisor)
|
||||
REPO=$(git remote get-url origin 2>/dev/null | sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
|
||||
jq -n \
|
||||
--arg session "$SESSION" \
|
||||
--arg repo "$REPO" \
|
||||
--argjson threshold 300 \
|
||||
'{active:true, tmux_session:$session, idle_threshold_seconds:$threshold,
|
||||
repo:$repo, loop_window:null, supervisor_window:null, last_poll_at:0, agents:[]}' \
|
||||
> ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
Optionally add a Discord webhook for completion notifications:
|
||||
```bash
|
||||
jq --arg hook "$DISCORD_WEBHOOK_URL" '.discord_webhook = $hook' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
`spawn-agent.sh` writes the initial agent record (window, worktree_path, branch, objective, state, etc.) to the state file automatically — **do not append the record again after calling it.** The record already exists and `pr_number`/`steps` are patched in by the script itself.
|
||||
|
||||
### 5. Start the mechanical babysitter
|
||||
|
||||
```bash
|
||||
LOOP_WIN=$(tmux new-window -t "$SESSION" -n "orchestrator" -P -F '#{window_index}')
|
||||
LOOP_WINDOW="${SESSION}:${LOOP_WIN}"
|
||||
tmux send-keys -t "$LOOP_WINDOW" "bash $SKILLS_DIR/run-loop.sh" Enter
|
||||
|
||||
jq --arg w "$LOOP_WINDOW" '.loop_window = $w' ~/.claude/orchestrator-state.json \
|
||||
> /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
### 6. Begin supervising directly in this conversation
|
||||
|
||||
You are the supervisor. After spawning, immediately start your first poll loop (see **Supervisor duties** below) and continue every 2-3 minutes. Do NOT spawn a separate supervisor Claude window.
|
||||
|
||||
## Adding an agent
|
||||
|
||||
Find the next spare worktree, then spawn and append to state — same as steps 2–4 above but for a single task. If no spare worktrees are available, tell the user.
|
||||
|
||||
## Supervisor duties (YOUR job, every 2-3 min in this conversation)
|
||||
|
||||
You are the supervisor. Run this poll loop directly in your Claude session — not in a separate window.
|
||||
|
||||
### Poll loop mechanism
|
||||
|
||||
You are reactive — you only act when a tool completes or the user sends a message. To create a self-sustaining poll loop without user involvement:
|
||||
|
||||
1. Start each poll with `run_in_background: true` + a sleep before the work:
|
||||
```bash
|
||||
sleep 120 && tmux capture-pane -t autogpt1:0 -p -S -200 | tail -40
|
||||
# + similar for each active window
|
||||
```
|
||||
2. When the background job notifies you, read the pane output and take action.
|
||||
3. Immediately schedule the next background poll — this keeps the loop alive.
|
||||
4. Stop scheduling when all agents are done/escalated.
|
||||
|
||||
**Never tell the user "I'll poll every 2-3 minutes"** — that does nothing without a trigger. Start the background job instead.
|
||||
|
||||
### Each poll: what to check
|
||||
|
||||
```bash
|
||||
# 1. Read state
|
||||
cat ~/.claude/orchestrator-state.json | jq '.agents[] | {window, worktree, branch, state, pr_number, checkpoints}'
|
||||
|
||||
# 2. For each running/stuck/idle agent, capture pane
|
||||
tmux capture-pane -t SESSION:WIN -p -S -200 | tail -60
|
||||
```
|
||||
|
||||
For each agent, decide:
|
||||
|
||||
| What you see | Action |
|
||||
|---|---|
|
||||
| Spinner / tools running | Do nothing — agent is working |
|
||||
| Idle `❯` prompt, no `ORCHESTRATOR:DONE` | Stalled — send specific nudge with objective from state |
|
||||
| Stuck in error loop | Send targeted fix with exact error + solution |
|
||||
| Waiting for input / question | Answer and unblock via `tmux send-keys` |
|
||||
| CI red | `gh pr checks PR_NUMBER --repo REPO` → tell agent exactly what's failing |
|
||||
| Context compacted / agent lost | Send recovery: `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="WIN")'` + `gh pr view PR_NUMBER --json title,body` |
|
||||
| `ORCHESTRATOR:DONE` in output | Run `verify-complete.sh` — if it fails, re-brief with specific reason |
|
||||
|
||||
### Strict ORCHESTRATOR:DONE gate
|
||||
|
||||
`verify-complete.sh` handles the main checks automatically (checkpoints, threads, CI green, spawned_at, and CHANGES_REQUESTED). Run it:
|
||||
|
||||
**CHANGES_REQUESTED staleness rule**: a `CHANGES_REQUESTED` review only blocks if it was submitted *after* the latest commit. If the latest commit postdates the review, the review is considered stale (feedback already addressed) and does not block. This avoids false negatives when a bot reviewer hasn't re-reviewed after the agent's fixing commits.
|
||||
|
||||
```bash
|
||||
SKILLS_DIR=~/.claude/orchestrator/scripts
|
||||
bash $SKILLS_DIR/verify-complete.sh SESSION:WIN
|
||||
```
|
||||
|
||||
If it passes → run-loop.sh will recycle the window automatically. No manual action needed.
|
||||
If it fails → re-brief the agent with the failure reason. Never manually mark state `done` to bypass this.
|
||||
|
||||
### Re-brief a stalled agent
|
||||
|
||||
```bash
|
||||
OBJ=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .objective' ~/.claude/orchestrator-state.json)
|
||||
PR=$(jq -r --arg w SESSION:WIN '.agents[] | select(.window==$w) | .pr_number' ~/.claude/orchestrator-state.json)
|
||||
tmux send-keys -t SESSION:WIN "You appear stalled. Your objective: $OBJ. Check: gh pr view $PR --json title,body,headRefName to reorient."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
|
||||
If `image_path` is set on the agent record, include: "Re-read context at IMAGE_PATH with the Read tool."
|
||||
|
||||
## Self-recovery protocol (agents)
|
||||
|
||||
spawn-agent.sh automatically includes this instruction in every objective:
|
||||
|
||||
> If your context compacts and you lose track of what to do, run:
|
||||
> `cat ~/.claude/orchestrator-state.json | jq '.agents[] | select(.window=="SESSION:WIN")'`
|
||||
> and `gh pr view PR_NUMBER --json title,body,headRefName` to reorient.
|
||||
> Output each completed step as `CHECKPOINT:<step-name>` on its own line.
|
||||
|
||||
## Passing images and screenshots to agents
|
||||
|
||||
`tmux send-keys` is text-only — you cannot paste a raw image into a pane. To give an agent visual context (screenshots, diagrams, mockups):
|
||||
|
||||
1. **Save the image to a temp file** with a stable path:
|
||||
```bash
|
||||
# If the user drags in a screenshot or you receive a file path:
|
||||
IMAGE_PATH="/tmp/orchestrator-context-$(date +%s).png"
|
||||
cp "$USER_PROVIDED_PATH" "$IMAGE_PATH"
|
||||
```
|
||||
|
||||
2. **Reference the path in the objective string**:
|
||||
```bash
|
||||
OBJECTIVE="Implement the layout shown in /tmp/orchestrator-context-1234567890.png. Read that image first with the Read tool to understand the design."
|
||||
```
|
||||
|
||||
3. The agent uses its `Read` tool to view the image at startup — Claude Code agents are multimodal and can read image files directly.
|
||||
|
||||
**Rule**: always use `/tmp/orchestrator-context-<timestamp>.png` as the naming convention so the supervisor knows what to look for if it needs to re-brief an agent with the same image.
|
||||
|
||||
---
|
||||
|
||||
## Orchestrator final evaluation (YOU decide, not the script)
|
||||
|
||||
`verify-complete.sh` is a gate — it blocks premature marking. But it cannot tell you if the work is actually good. That is YOUR job.
|
||||
|
||||
When run-loop marks an agent `pending_evaluation` and you're notified, do all of these before marking done:
|
||||
|
||||
### 1. Run /pr-test (required, serialized, use TodoWrite to queue)
|
||||
|
||||
`/pr-test` is the only reliable confirmation that the objective is actually met. Run it yourself, not the agent.
|
||||
|
||||
**When multiple PRs reach `pending_evaluation` at the same time, use TodoWrite to queue them:**
|
||||
```
|
||||
- [ ] /pr-test PR #12636 — fix copilot retry logic
|
||||
- [ ] /pr-test PR #12699 — builder chat panel
|
||||
```
|
||||
Run one at a time. Check off as you go.
|
||||
|
||||
```
|
||||
/pr-test https://github.com/Significant-Gravitas/AutoGPT/pull/PR_NUMBER
|
||||
```
|
||||
|
||||
**/pr-test can be lazy** — if it gives vague output, re-run with full context:
|
||||
|
||||
```
|
||||
/pr-test https://github.com/OWNER/REPO/pull/PR_NUMBER
|
||||
Context: This PR implements <objective from state file>. Key files: <list>.
|
||||
Please verify: <specific behaviors to check>.
|
||||
```
|
||||
|
||||
Only one `/pr-test` at a time — they share ports and DB.
|
||||
|
||||
### /pr-test result evaluation
|
||||
|
||||
**PARTIAL on any headline feature scenario is an immediate blocker.** Do not approve, do not mark done, do not let the agent output `ORCHESTRATOR:DONE`.
|
||||
|
||||
| `/pr-test` result | Action |
|
||||
|---|---|
|
||||
| All headline scenarios **PASS** | Proceed to evaluation step 2 |
|
||||
| Any headline scenario **PARTIAL** | Re-brief the agent immediately — see below |
|
||||
| Any headline scenario **FAIL** | Re-brief the agent immediately |
|
||||
|
||||
**What PARTIAL means**: the feature is only partly working. Example: the Apply button never appeared, or the AI returned no action blocks. The agent addressed part of the objective but not all of it.
|
||||
|
||||
**When any headline scenario is PARTIAL or FAIL:**
|
||||
|
||||
1. Do NOT mark the agent done or accept `ORCHESTRATOR:DONE`
|
||||
2. Re-brief the agent with the specific scenario that failed and what was missing:
|
||||
```bash
|
||||
tmux send-keys -t SESSION:WIN "PARTIAL result on /pr-test — S5 (Apply button) never appeared. The AI must output JSON action blocks for the Apply button to render. Fix this before re-running /pr-test."
|
||||
sleep 0.3
|
||||
tmux send-keys -t SESSION:WIN Enter
|
||||
```
|
||||
3. Set state back to `running`:
|
||||
```bash
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "running"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
4. Wait for new `ORCHESTRATOR:DONE`, then re-run `/pr-test` from scratch
|
||||
|
||||
**Rule: only ALL-PASS qualifies for approval.** A mix of PASS + PARTIAL is a failure.
|
||||
|
||||
> **Why this matters**: PR #12699 was wrongly approved with S5 PARTIAL — the AI never output JSON action blocks so the Apply button never appeared. The fix was already in the agent's reach but slipped through because PARTIAL was not treated as blocking.
|
||||
|
||||
### 2. Do your own evaluation
|
||||
|
||||
1. **Read the PR diff and objective** — does the code actually implement what was asked? Is anything obviously missing or half-done?
|
||||
2. **Read the resolved threads** — were comments addressed with real fixes, or just dismissed/resolved without changes?
|
||||
3. **Check CI run names** — any suspicious retries that shouldn't have passed?
|
||||
4. **Check the PR description** — title, summary, test plan complete?
|
||||
|
||||
### 3. Decide
|
||||
|
||||
- `/pr-test` all scenarios PASS + evaluation looks good → mark `done` in state, tell the user the PR is ready, ask if window should be closed
|
||||
- `/pr-test` any scenario PARTIAL or FAIL → re-brief the agent with the specific failing scenario, set state back to `running` (see `/pr-test result evaluation` above)
|
||||
- Evaluation finds gaps even with all PASS → re-brief the agent with specific gaps, set state back to `running`
|
||||
|
||||
**Never mark done based purely on script output.** You hold the full objective context; the script does not.
|
||||
|
||||
```bash
|
||||
# Mark done after your positive evaluation:
|
||||
jq --arg w "SESSION:WIN" '(.agents[] | select(.window == $w)).state = "done"' \
|
||||
~/.claude/orchestrator-state.json > /tmp/orch.tmp && mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
```
|
||||
|
||||
## When to stop the fleet
|
||||
|
||||
Stop the fleet (`active = false`) when **all** of the following are true:
|
||||
|
||||
| Check | How to verify |
|
||||
|---|---|
|
||||
| All agents are `done` or `escalated` | `jq '[.agents[] | select(.state | test("running\|stuck\|idle\|waiting_approval"))] | length' ~/.claude/orchestrator-state.json` == 0 |
|
||||
| All PRs have 0 unresolved review threads | GraphQL `isResolved` check per PR |
|
||||
| All PRs have green CI **on a run triggered after the agent's last push** | `gh run list --branch BRANCH --limit 1` timestamp > `spawned_at` in state |
|
||||
| No fresh CHANGES_REQUESTED (after latest commit) | `verify-complete.sh` checks this — stale pre-commit reviews are ignored |
|
||||
| No agents are `escalated` without human review | If any are escalated, surface to user first |
|
||||
|
||||
**Do NOT stop just because agents output `ORCHESTRATOR:DONE`.** That is a signal to verify, not a signal to stop.
|
||||
|
||||
**Do stop** if the user explicitly says "stop", "shut down", or "kill everything", even with agents still running.
|
||||
|
||||
```bash
|
||||
# Graceful stop
|
||||
jq '.active = false' ~/.claude/orchestrator-state.json > /tmp/orch.tmp \
|
||||
&& mv /tmp/orch.tmp ~/.claude/orchestrator-state.json
|
||||
|
||||
LOOP_WINDOW=$(jq -r '.loop_window // ""' ~/.claude/orchestrator-state.json)
|
||||
[ -n "$LOOP_WINDOW" ] && tmux kill-window -t "$LOOP_WINDOW" 2>/dev/null || true
|
||||
```
|
||||
|
||||
Does **not** recycle running worktrees — agents may still be mid-task. Run `capacity.sh` to see what's still in progress.
|
||||
|
||||
## tmux send-keys pattern
|
||||
|
||||
**Always split long messages into text + Enter as two separate calls with a sleep between them.** If sent as one call (`"text" Enter`), Enter can fire before the full string is buffered into Claude's input — leaving the message stuck as `[Pasted text +N lines]` unsent.
|
||||
|
||||
```bash
|
||||
# CORRECT — text then Enter separately
|
||||
tmux send-keys -t "$WINDOW" "your long message here"
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# WRONG — Enter may fire before text is buffered
|
||||
tmux send-keys -t "$WINDOW" "your long message here" Enter
|
||||
```
|
||||
|
||||
Short single-character sends (`y`, `Down`, empty Enter for dialog approval) are safe to combine since they have no buffering lag.
|
||||
|
||||
---
|
||||
|
||||
## Protected worktrees
|
||||
|
||||
Some worktrees must **never** be used as spare worktrees for agent tasks because they host files critical to the orchestrator itself:
|
||||
|
||||
| Worktree | Protected branch | Why |
|
||||
|---|---|---|
|
||||
| `AutoGPT1` | `dx/orchestrate-skill` | Hosts the orchestrate skill scripts. `recycle-agent.sh` would check out `spare/1`, wiping `.claude/skills/` and breaking all subsequent `spawn-agent.sh` calls. |
|
||||
|
||||
**Rule**: when selecting spare worktrees via `find-spare.sh`, skip any worktree whose CURRENT branch matches a protected branch. If you accidentally spawn an agent in a protected worktree, do not let `recycle-agent.sh` run on it — manually restore the branch after the agent finishes.
|
||||
|
||||
When `dx/orchestrate-skill` is merged into `dev`, `AutoGPT1` becomes a normal spare again.
|
||||
|
||||
---
|
||||
|
||||
## Key rules
|
||||
|
||||
1. **Scripts do all the heavy lifting** — don't reimplement their logic inline in this file
|
||||
2. **Never ask the user to pick a worktree** — auto-assign from `find-spare.sh` output
|
||||
3. **Never restart a running agent** — only restart on `idle` kicks (foreground is a shell)
|
||||
4. **Auto-dismiss settings dialogs** — if "Enter to confirm" appears, send Down+Enter
|
||||
5. **Always `--permission-mode bypassPermissions`** on every spawn
|
||||
6. **Escalate after 3 kicks** — mark `escalated`, surface to user
|
||||
7. **Atomic state writes** — always write to `.tmp` then `mv`
|
||||
8. **Never approve destructive commands** outside the worktree scope — when in doubt, escalate
|
||||
9. **Never recycle without verification** — `verify-complete.sh` must pass before recycling
|
||||
10. **No TASK.md files** — commit risk; use state file + `gh pr view` for agent context persistence
|
||||
11. **Re-brief stalled agents** — read objective from state file + `gh pr view`, send via tmux
|
||||
12. **ORCHESTRATOR:DONE is a signal to verify, not to accept** — always run `verify-complete.sh` and check CI run timestamp before recycling
|
||||
13. **Protected worktrees** — never use the worktree hosting the skill scripts as a spare
|
||||
14. **Images via file path** — save screenshots to `/tmp/orchestrator-context-<ts>.png`, pass path in objective; agents read with the `Read` tool
|
||||
15. **Split send-keys** — always separate text and Enter with `sleep 0.3` between calls for long strings
|
||||
43
.claude/skills/orchestrate/scripts/capacity.sh
Executable file
43
.claude/skills/orchestrate/scripts/capacity.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# capacity.sh — show fleet capacity: available spare worktrees + in-use agents
|
||||
#
|
||||
# Usage: capacity.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree of the current git repo.
|
||||
#
|
||||
# Reads: ~/.claude/orchestrator-state.json (skipped if missing or corrupt)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
|
||||
echo "=== Available (spare) worktrees ==="
|
||||
if [ -n "$REPO_ROOT" ]; then
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" "$REPO_ROOT" 2>/dev/null || echo "")
|
||||
else
|
||||
SPARE=$("$SCRIPTS_DIR/find-spare.sh" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$SPARE" ]; then
|
||||
echo " (none)"
|
||||
else
|
||||
while IFS= read -r line; do
|
||||
[ -z "$line" ] && continue
|
||||
echo " ✓ $line"
|
||||
done <<< "$SPARE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== In-use worktrees ==="
|
||||
if [ -f "$STATE_FILE" ] && jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
IN_USE=$(jq -r '.agents[] | select(.state != "done") | " [\(.state)] \(.worktree_path) → \(.branch)"' \
|
||||
"$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -n "$IN_USE" ]; then
|
||||
echo "$IN_USE"
|
||||
else
|
||||
echo " (none)"
|
||||
fi
|
||||
else
|
||||
echo " (no active state file)"
|
||||
fi
|
||||
85
.claude/skills/orchestrate/scripts/classify-pane.sh
Executable file
85
.claude/skills/orchestrate/scripts/classify-pane.sh
Executable file
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env bash
|
||||
# classify-pane.sh — Classify the current state of a tmux pane
|
||||
#
|
||||
# Usage: classify-pane.sh <tmux-target>
|
||||
# tmux-target: e.g. "work:0", "work:1.0"
|
||||
#
|
||||
# Output (stdout): JSON object:
|
||||
# { "state": "running|idle|waiting_approval|complete", "reason": "...", "pane_cmd": "..." }
|
||||
#
|
||||
# Exit codes: 0=ok, 1=error (invalid target or tmux window not found)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
TARGET="${1:-}"
|
||||
|
||||
if [ -z "$TARGET" ]; then
|
||||
echo '{"state":"error","reason":"no target provided","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate tmux target format: session:window or session:window.pane
|
||||
if ! [[ "$TARGET" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo '{"state":"error","reason":"invalid tmux target format","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check session exists (use %%:* to extract session name from session:window)
|
||||
if ! tmux list-windows -t "${TARGET%%:*}" &>/dev/null 2>&1; then
|
||||
echo '{"state":"error","reason":"tmux target not found","pane_cmd":""}'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the current foreground command in the pane
|
||||
PANE_CMD=$(tmux display-message -t "$TARGET" -p '#{pane_current_command}' 2>/dev/null || echo "unknown")
|
||||
|
||||
# Capture and strip ANSI codes (use perl for cross-platform compatibility — BSD sed lacks \x1b support)
|
||||
RAW=$(tmux capture-pane -t "$TARGET" -p -S -50 2>/dev/null || echo "")
|
||||
CLEAN=$(echo "$RAW" | perl -pe 's/\x1b\[[0-9;]*[a-zA-Z]//g; s/\x1b\(B//g; s/\x1b\[\?[0-9]*[hl]//g; s/\r//g' \
|
||||
| grep -v '^[[:space:]]*$' || true)
|
||||
|
||||
# --- Check: explicit completion marker ---
|
||||
# Must be on its own line (not buried in the objective text sent at spawn time).
|
||||
if echo "$CLEAN" | grep -qE "^[[:space:]]*ORCHESTRATOR:DONE[[:space:]]*$"; then
|
||||
jq -n --arg cmd "$PANE_CMD" '{"state":"complete","reason":"ORCHESTRATOR:DONE marker found","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# --- Check: Claude Code approval prompt patterns ---
|
||||
LAST_40=$(echo "$CLEAN" | tail -40)
|
||||
APPROVAL_PATTERNS=(
|
||||
"Do you want to proceed"
|
||||
"Do you want to make this"
|
||||
"\\[y/n\\]"
|
||||
"\\[Y/n\\]"
|
||||
"\\[n/Y\\]"
|
||||
"Proceed\\?"
|
||||
"Allow this command"
|
||||
"Run bash command"
|
||||
"Allow bash"
|
||||
"Would you like"
|
||||
"Press enter to continue"
|
||||
"Esc to cancel"
|
||||
)
|
||||
for pattern in "${APPROVAL_PATTERNS[@]}"; do
|
||||
if echo "$LAST_40" | grep -qiE "$pattern"; then
|
||||
jq -n --arg pattern "$pattern" --arg cmd "$PANE_CMD" \
|
||||
'{"state":"waiting_approval","reason":"approval pattern: \($pattern)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
|
||||
# --- Check: shell prompt (claude has exited) ---
|
||||
# If the foreground process is a shell (not claude/node), the agent has exited
|
||||
case "$PANE_CMD" in
|
||||
zsh|bash|fish|sh|dash|tcsh|ksh)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"idle","reason":"agent exited — shell prompt active","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
|
||||
# Agent is still running (claude/node/python is the foreground process)
|
||||
jq -n --arg cmd "$PANE_CMD" \
|
||||
'{"state":"running","reason":"foreground process: \($cmd)","pane_cmd":$cmd}'
|
||||
exit 0
|
||||
24
.claude/skills/orchestrate/scripts/find-spare.sh
Executable file
24
.claude/skills/orchestrate/scripts/find-spare.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
# find-spare.sh — list worktrees on spare/N branches (free to use)
|
||||
#
|
||||
# Usage: find-spare.sh [REPO_ROOT]
|
||||
# REPO_ROOT defaults to the root worktree containing the current git repo.
|
||||
#
|
||||
# Output (stdout): one line per available worktree: "PATH BRANCH"
|
||||
# e.g.: /Users/me/Code/AutoGPT3 spare/3
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ROOT="${1:-$(git rev-parse --show-toplevel 2>/dev/null || echo "")}"
|
||||
if [ -z "$REPO_ROOT" ]; then
|
||||
echo "Error: not inside a git repo and no REPO_ROOT provided" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
git -C "$REPO_ROOT" worktree list --porcelain \
|
||||
| awk '
|
||||
/^worktree / { path = substr($0, 10) }
|
||||
/^branch / { branch = substr($0, 8); print path " " branch }
|
||||
' \
|
||||
| { grep -E " refs/heads/spare/[0-9]+$" || true; } \
|
||||
| sed 's|refs/heads/||'
|
||||
40
.claude/skills/orchestrate/scripts/notify.sh
Executable file
40
.claude/skills/orchestrate/scripts/notify.sh
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
# notify.sh — send a fleet notification message
|
||||
#
|
||||
# Delivery order (first available wins):
|
||||
# 1. Discord webhook — DISCORD_WEBHOOK_URL env var OR state file .discord_webhook
|
||||
# 2. macOS notification center — osascript (silent fail if unavailable)
|
||||
# 3. Stdout only
|
||||
#
|
||||
# Usage: notify.sh MESSAGE
|
||||
# Exit: always 0 (notification failure must not abort the caller)
|
||||
|
||||
MESSAGE="${1:-}"
|
||||
[ -z "$MESSAGE" ] && exit 0
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# --- Resolve Discord webhook ---
|
||||
WEBHOOK="${DISCORD_WEBHOOK_URL:-}"
|
||||
if [ -z "$WEBHOOK" ] && [ -f "$STATE_FILE" ]; then
|
||||
WEBHOOK=$(jq -r '.discord_webhook // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# --- Discord delivery ---
|
||||
if [ -n "$WEBHOOK" ]; then
|
||||
PAYLOAD=$(jq -n --arg msg "$MESSAGE" '{"content": $msg}')
|
||||
curl -s -X POST "$WEBHOOK" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$PAYLOAD" > /dev/null 2>&1 || true
|
||||
fi
|
||||
|
||||
# --- macOS notification center (silent if not macOS or osascript missing) ---
|
||||
if command -v osascript &>/dev/null 2>&1; then
|
||||
# Escape single quotes for AppleScript
|
||||
SAFE_MSG=$(echo "$MESSAGE" | sed "s/'/\\\\'/g")
|
||||
osascript -e "display notification \"${SAFE_MSG}\" with title \"Orchestrator\"" 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Always print to stdout so run-loop.sh logs it
|
||||
echo "$MESSAGE"
|
||||
exit 0
|
||||
257
.claude/skills/orchestrate/scripts/poll-cycle.sh
Executable file
257
.claude/skills/orchestrate/scripts/poll-cycle.sh
Executable file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env bash
|
||||
# poll-cycle.sh — Single orchestrator poll cycle
|
||||
#
|
||||
# Reads ~/.claude/orchestrator-state.json, classifies each agent, updates state,
|
||||
# and outputs a JSON array of actions for Claude to take.
|
||||
#
|
||||
# Usage: poll-cycle.sh
|
||||
# Output (stdout): JSON array of action objects
|
||||
# [{ "window": "work:0", "action": "kick|approve|none", "state": "...",
|
||||
# "worktree": "...", "objective": "...", "reason": "..." }]
|
||||
#
|
||||
# The state file is updated in-place (atomic write via .tmp).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CLASSIFY="$SCRIPTS_DIR/classify-pane.sh"
|
||||
|
||||
# Cross-platform md5: always outputs just the hex digest
|
||||
md5_hash() {
|
||||
if command -v md5sum &>/dev/null; then
|
||||
md5sum | awk '{print $1}'
|
||||
else
|
||||
md5 | awk '{print $NF}'
|
||||
fi
|
||||
}
|
||||
|
||||
# Clean up temp file on any exit (avoids stale .tmp if jq write fails)
|
||||
trap 'rm -f "${STATE_FILE}.tmp"' EXIT
|
||||
|
||||
# Ensure state file exists
|
||||
if [ ! -f "$STATE_FILE" ]; then
|
||||
echo '{"active":false,"agents":[]}' > "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Validate JSON upfront before any jq reads that run under set -e.
|
||||
# A truncated/corrupt file (e.g. from a SIGKILL mid-write) would otherwise
|
||||
# abort the script at the ACTIVE read below without emitting any JSON output.
|
||||
if ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIVE=$(jq -r '.active // false' "$STATE_FILE")
|
||||
if [ "$ACTIVE" != "true" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NOW=$(date +%s)
|
||||
IDLE_THRESHOLD=$(jq -r '.idle_threshold_seconds // 300' "$STATE_FILE")
|
||||
|
||||
ACTIONS="[]"
|
||||
UPDATED_AGENTS="[]"
|
||||
|
||||
# Read agents as newline-delimited JSON objects.
|
||||
# jq exits non-zero when .agents[] has no matches on an empty array, which is valid —
|
||||
# so we suppress that exit code and separately validate the file is well-formed JSON.
|
||||
if ! AGENTS_JSON=$(jq -e -c '.agents // empty | .[]' "$STATE_FILE" 2>/dev/null); then
|
||||
if ! jq -e '.' "$STATE_FILE" > /dev/null 2>&1; then
|
||||
echo "State file parse error — check $STATE_FILE" >&2
|
||||
fi
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ -z "$AGENTS_JSON" ]; then
|
||||
echo "[]"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
while IFS= read -r agent; do
|
||||
[ -z "$agent" ] && continue
|
||||
|
||||
# Use // "" defaults so a single malformed field doesn't abort the whole cycle
|
||||
WINDOW=$(echo "$agent" | jq -r '.window // ""')
|
||||
WORKTREE=$(echo "$agent" | jq -r '.worktree // ""')
|
||||
OBJECTIVE=$(echo "$agent"| jq -r '.objective // ""')
|
||||
STATE=$(echo "$agent" | jq -r '.state // "running"')
|
||||
LAST_HASH=$(echo "$agent"| jq -r '.last_output_hash // ""')
|
||||
IDLE_SINCE=$(echo "$agent"| jq -r '.idle_since // 0')
|
||||
REVISION_COUNT=$(echo "$agent"| jq -r '.revision_count // 0')
|
||||
|
||||
# Validate window format to prevent tmux target injection.
|
||||
# Allow session:window (numeric or named) and session:window.pane
|
||||
if ! [[ "$WINDOW" =~ ^[a-zA-Z0-9_.-]+:[a-zA-Z0-9_.-]+(\.[0-9]+)?$ ]]; then
|
||||
echo "Skipping agent with invalid window value: $WINDOW" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Pass-through terminal-state agents
|
||||
if [[ "$STATE" == "done" || "$STATE" == "escalated" || "$STATE" == "complete" || "$STATE" == "pending_evaluation" ]]; then
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
fi
|
||||
|
||||
# Classify pane.
|
||||
# classify-pane.sh always emits JSON before exit (even on error), so using
|
||||
# "|| echo '...'" would concatenate two JSON objects when it exits non-zero.
|
||||
# Use "|| true" inside the substitution so set -euo pipefail does not abort
|
||||
# the poll cycle when classify exits with a non-zero status code.
|
||||
CLASSIFICATION=$("$CLASSIFY" "$WINDOW" 2>/dev/null || true)
|
||||
[ -z "$CLASSIFICATION" ] && CLASSIFICATION='{"state":"error","reason":"classify failed","pane_cmd":"unknown"}'
|
||||
|
||||
PANE_STATE=$(echo "$CLASSIFICATION" | jq -r '.state')
|
||||
PANE_REASON=$(echo "$CLASSIFICATION" | jq -r '.reason')
|
||||
|
||||
# Capture full pane output once — used for hash (stuck detection) and checkpoint parsing.
|
||||
# Use -S -500 to get the last ~500 lines of scrollback so checkpoints aren't missed.
|
||||
RAW=$(tmux capture-pane -t "$WINDOW" -p -S -500 2>/dev/null || echo "")
|
||||
|
||||
# --- Checkpoint tracking ---
|
||||
# Parse any "CHECKPOINT:<step>" lines the agent has output and merge into state file.
|
||||
# The agent writes these as it completes each required step so verify-complete.sh can gate recycling.
|
||||
EXISTING_CPS=$(echo "$agent" | jq -c '.checkpoints // []')
|
||||
NEW_CHECKPOINTS_JSON="$EXISTING_CPS"
|
||||
if [ -n "$RAW" ]; then
|
||||
FOUND_CPS=$(echo "$RAW" \
|
||||
| grep -oE "CHECKPOINT:[a-zA-Z0-9_-]+" \
|
||||
| sed 's/CHECKPOINT://' \
|
||||
| sort -u \
|
||||
| jq -R . | jq -s . 2>/dev/null || echo "[]")
|
||||
NEW_CHECKPOINTS_JSON=$(jq -n \
|
||||
--argjson existing "$EXISTING_CPS" \
|
||||
--argjson found "$FOUND_CPS" \
|
||||
'($existing + $found) | unique' 2>/dev/null || echo "$EXISTING_CPS")
|
||||
fi
|
||||
|
||||
# Compute content hash for stuck-detection (only for running agents)
|
||||
CURRENT_HASH=""
|
||||
if [[ "$PANE_STATE" == "running" ]] && [ -n "$RAW" ]; then
|
||||
CURRENT_HASH=$(echo "$RAW" | tail -20 | md5_hash)
|
||||
fi
|
||||
|
||||
NEW_STATE="$STATE"
|
||||
NEW_IDLE_SINCE="$IDLE_SINCE"
|
||||
NEW_REVISION_COUNT="$REVISION_COUNT"
|
||||
ACTION="none"
|
||||
REASON="$PANE_REASON"
|
||||
|
||||
case "$PANE_STATE" in
|
||||
complete)
|
||||
# Agent output ORCHESTRATOR:DONE — mark pending_evaluation so orchestrator handles it.
|
||||
# run-loop does NOT verify or notify; orchestrator's background poll picks this up.
|
||||
NEW_STATE="pending_evaluation"
|
||||
ACTION="complete" # run-loop logs it but takes no action
|
||||
;;
|
||||
waiting_approval)
|
||||
NEW_STATE="waiting_approval"
|
||||
ACTION="approve"
|
||||
;;
|
||||
idle)
|
||||
# Agent process has exited — needs restart
|
||||
NEW_STATE="idle"
|
||||
ACTION="kick"
|
||||
REASON="agent exited (shell is foreground)"
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
fi
|
||||
;;
|
||||
running)
|
||||
# Clear idle_since only when transitioning from idle (agent was kicked and
|
||||
# restarted). Do NOT reset for stuck — idle_since must persist across polls
|
||||
# so STUCK_DURATION can accumulate and trigger escalation.
|
||||
# Also update the local IDLE_SINCE so the hash-stability check below uses
|
||||
# the reset value on this same poll, not the stale kick timestamp.
|
||||
if [[ "$STATE" == "idle" ]]; then
|
||||
NEW_IDLE_SINCE=0
|
||||
IDLE_SINCE=0
|
||||
fi
|
||||
# Check if hash has been stable (agent may be stuck mid-task)
|
||||
if [ -n "$CURRENT_HASH" ] && [ "$CURRENT_HASH" = "$LAST_HASH" ] && [ "$LAST_HASH" != "" ]; then
|
||||
if [ "$IDLE_SINCE" = "0" ] || [ "$IDLE_SINCE" = "null" ]; then
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
else
|
||||
STUCK_DURATION=$(( NOW - IDLE_SINCE ))
|
||||
if [ "$STUCK_DURATION" -gt "$IDLE_THRESHOLD" ]; then
|
||||
NEW_REVISION_COUNT=$(( REVISION_COUNT + 1 ))
|
||||
NEW_IDLE_SINCE=$NOW
|
||||
if [ "$NEW_REVISION_COUNT" -ge 3 ]; then
|
||||
NEW_STATE="escalated"
|
||||
ACTION="none"
|
||||
REASON="escalated after ${NEW_REVISION_COUNT} kicks — needs human attention"
|
||||
else
|
||||
NEW_STATE="stuck"
|
||||
ACTION="kick"
|
||||
REASON="output unchanged for ${STUCK_DURATION}s (threshold: ${IDLE_THRESHOLD}s)"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
else
|
||||
# Only reset the idle timer when we have a valid hash comparison (pane
|
||||
# capture succeeded). If CURRENT_HASH is empty (tmux capture-pane failed),
|
||||
# preserve existing timers so stuck detection is not inadvertently reset.
|
||||
if [ -n "$CURRENT_HASH" ]; then
|
||||
NEW_STATE="running"
|
||||
NEW_IDLE_SINCE=0
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
error)
|
||||
REASON="classify error: $PANE_REASON"
|
||||
;;
|
||||
esac
|
||||
|
||||
# Build updated agent record (ensure idle_since and revision_count are numeric)
|
||||
# Use || true on each jq call so a malformed field skips this agent rather than
|
||||
# aborting the entire poll cycle under set -e.
|
||||
UPDATED_AGENT=$(echo "$agent" | jq \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg hash "$CURRENT_HASH" \
|
||||
--argjson now "$NOW" \
|
||||
--arg idle_since "$NEW_IDLE_SINCE" \
|
||||
--arg revision_count "$NEW_REVISION_COUNT" \
|
||||
--argjson checkpoints "$NEW_CHECKPOINTS_JSON" \
|
||||
'.state = $state
|
||||
| .last_output_hash = (if $hash == "" then .last_output_hash else $hash end)
|
||||
| .last_seen_at = $now
|
||||
| .idle_since = ($idle_since | tonumber)
|
||||
| .revision_count = ($revision_count | tonumber)
|
||||
| .checkpoints = $checkpoints' 2>/dev/null) || {
|
||||
echo "Warning: failed to build updated agent for window $WINDOW — keeping original" >&2
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$agent" '. + [$a]')
|
||||
continue
|
||||
}
|
||||
|
||||
UPDATED_AGENTS=$(echo "$UPDATED_AGENTS" | jq --argjson a "$UPDATED_AGENT" '. + [$a]')
|
||||
|
||||
# Add action if needed
|
||||
if [ "$ACTION" != "none" ]; then
|
||||
ACTION_OBJ=$(jq -n \
|
||||
--arg window "$WINDOW" \
|
||||
--arg action "$ACTION" \
|
||||
--arg state "$NEW_STATE" \
|
||||
--arg worktree "$WORKTREE" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg reason "$REASON" \
|
||||
'{window:$window, action:$action, state:$state, worktree:$worktree, objective:$objective, reason:$reason}')
|
||||
ACTIONS=$(echo "$ACTIONS" | jq --argjson a "$ACTION_OBJ" '. + [$a]')
|
||||
fi
|
||||
|
||||
done <<< "$AGENTS_JSON"
|
||||
|
||||
# Atomic state file update
|
||||
jq --argjson agents "$UPDATED_AGENTS" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents = $agents | .last_poll_at = $now' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
|
||||
echo "$ACTIONS"
|
||||
32
.claude/skills/orchestrate/scripts/recycle-agent.sh
Executable file
32
.claude/skills/orchestrate/scripts/recycle-agent.sh
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env bash
|
||||
# recycle-agent.sh — kill a tmux window and restore the worktree to its spare branch
|
||||
#
|
||||
# Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH
|
||||
# WINDOW — tmux target, e.g. autogpt1:3
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — branch to restore, e.g. spare/6
|
||||
#
|
||||
# Stdout: one status line
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: recycle-agent.sh WINDOW WORKTREE_PATH SPARE_BRANCH" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
WINDOW="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
|
||||
# Kill the tmux window (ignore error — may already be gone)
|
||||
tmux kill-window -t "$WINDOW" 2>/dev/null || true
|
||||
|
||||
# Restore to spare branch: abort any in-progress operation, then clean
|
||||
git -C "$WORKTREE_PATH" rebase --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" merge --abort 2>/dev/null || true
|
||||
git -C "$WORKTREE_PATH" reset --hard HEAD 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" clean -fd 2>/dev/null
|
||||
git -C "$WORKTREE_PATH" checkout "$SPARE_BRANCH"
|
||||
|
||||
echo "Recycled: $(basename "$WORKTREE_PATH") → $SPARE_BRANCH (window $WINDOW closed)"
|
||||
164
.claude/skills/orchestrate/scripts/run-loop.sh
Executable file
164
.claude/skills/orchestrate/scripts/run-loop.sh
Executable file
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env bash
|
||||
# run-loop.sh — Mechanical babysitter for the agent fleet (runs in its own tmux window)
|
||||
#
|
||||
# Handles ONLY two things that need no intelligence:
|
||||
# idle → restart claude using --resume SESSION_ID (or --continue) to restore context
|
||||
# approve → auto-approve safe dialogs, press Enter on numbered-option dialogs
|
||||
#
|
||||
# Everything else — ORCHESTRATOR:DONE, verification, /pr-test, final evaluation,
|
||||
# marking done, deciding to close windows — is the orchestrating Claude's job.
|
||||
# poll-cycle.sh sets state to pending_evaluation when ORCHESTRATOR:DONE is detected;
|
||||
# the orchestrator's background poll loop handles it from there.
|
||||
#
|
||||
# Usage: run-loop.sh
|
||||
# Env: POLL_INTERVAL (default: 30), ORCHESTRATOR_STATE_FILE
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Copy scripts to a stable location outside the repo so they survive branch
|
||||
# checkouts (e.g. recycle-agent.sh switching spare/N back into this worktree
|
||||
# would wipe .claude/skills/orchestrate/scripts if the skill only exists on the
|
||||
# current branch).
|
||||
_ORIGIN_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
STABLE_SCRIPTS_DIR="$HOME/.claude/orchestrator/scripts"
|
||||
mkdir -p "$STABLE_SCRIPTS_DIR"
|
||||
cp "$_ORIGIN_DIR"/*.sh "$STABLE_SCRIPTS_DIR/"
|
||||
chmod +x "$STABLE_SCRIPTS_DIR"/*.sh
|
||||
SCRIPTS_DIR="$STABLE_SCRIPTS_DIR"
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
POLL_INTERVAL="${POLL_INTERVAL:-30}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_state WINDOW FIELD VALUE
|
||||
# ---------------------------------------------------------------------------
|
||||
update_state() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --arg v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
update_state_int() {
|
||||
local window="$1" field="$2" value="$3"
|
||||
jq --arg w "$window" --arg f "$field" --argjson v "$value" \
|
||||
'.agents |= map(if .window == $w then .[$f] = $v else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
}
|
||||
|
||||
agent_field() {
|
||||
jq -r --arg w "$1" --arg f "$2" \
|
||||
'.agents[] | select(.window == $w) | .[$f] // ""' \
|
||||
"$STATE_FILE" 2>/dev/null
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wait_for_prompt WINDOW — wait up to 60s for Claude's ❯ prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
wait_for_prompt() {
|
||||
local window="$1"
|
||||
for i in $(seq 1 60); do
|
||||
local cmd pane
|
||||
cmd=$(tmux display-message -t "$window" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
pane=$(tmux capture-pane -t "$window" -p 2>/dev/null || echo "")
|
||||
if echo "$pane" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$window" Down Enter; sleep 2; continue
|
||||
fi
|
||||
[[ "$cmd" == "node" ]] && echo "$pane" | grep -q "❯" && return 0
|
||||
sleep 1
|
||||
done
|
||||
return 1 # timed out
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_kick WINDOW STATE — only for idle (crashed) agents, not stuck
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_kick() {
|
||||
local window="$1" state="$2"
|
||||
[[ "$state" != "idle" ]] && return # stuck agents handled by supervisor
|
||||
|
||||
local worktree_path session_id
|
||||
worktree_path=$(agent_field "$window" "worktree_path")
|
||||
session_id=$(agent_field "$window" "session_id")
|
||||
|
||||
echo "[$(date +%H:%M:%S)] KICK restart $window — agent exited, resuming session"
|
||||
|
||||
# Resume the exact session so the agent retains full context — no need to re-send objective
|
||||
if [ -n "$session_id" ]; then
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --resume '${session_id}' --permission-mode bypassPermissions" Enter
|
||||
else
|
||||
tmux send-keys -t "$window" "cd '${worktree_path}' && claude --continue --permission-mode bypassPermissions" Enter
|
||||
fi
|
||||
|
||||
wait_for_prompt "$window" || echo "[$(date +%H:%M:%S)] KICK WARNING $window — timed out waiting for ❯"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_approve WINDOW — auto-approve dialogs that need no judgment
|
||||
# ---------------------------------------------------------------------------
|
||||
handle_approve() {
|
||||
local window="$1"
|
||||
local pane_tail
|
||||
pane_tail=$(tmux capture-pane -t "$window" -p 2>/dev/null | tail -3 || echo "")
|
||||
|
||||
# Settings error dialog at startup
|
||||
if echo "$pane_tail" | grep -q "Enter to confirm"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE dialog $window — settings error"
|
||||
tmux send-keys -t "$window" Down Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Numbered-option dialog (e.g. "Do you want to make this edit?")
|
||||
# ❯ is already on option 1 (Yes) — Enter confirms it
|
||||
if echo "$pane_tail" | grep -qE "❯\s*1\." || echo "$pane_tail" | grep -q "Esc to cancel"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE edit $window"
|
||||
tmux send-keys -t "$window" "" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# y/n prompt for safe operations
|
||||
if echo "$pane_tail" | grep -qiE "(^git |^npm |^pnpm |^poetry |^pytest|^docker |^make |^cargo |^pip |^yarn |curl .*(localhost|127\.0\.0\.1))"; then
|
||||
echo "[$(date +%H:%M:%S)] APPROVE safe $window"
|
||||
tmux send-keys -t "$window" "y" Enter
|
||||
return
|
||||
fi
|
||||
|
||||
# Anything else — supervisor handles it, just log
|
||||
echo "[$(date +%H:%M:%S)] APPROVE skip $window — unknown dialog, supervisor will handle"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main loop
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "[$(date +%H:%M:%S)] run-loop started (mechanical only, poll every ${POLL_INTERVAL}s)"
|
||||
echo "[$(date +%H:%M:%S)] Supervisor: orchestrating Claude session (not a separate window)"
|
||||
echo "---"
|
||||
|
||||
while true; do
|
||||
if ! jq -e '.active == true' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "[$(date +%H:%M:%S)] active=false — exiting."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
ACTIONS=$("$SCRIPTS_DIR/poll-cycle.sh" 2>/dev/null || echo "[]")
|
||||
KICKED=0; DONE=0
|
||||
|
||||
while IFS= read -r action; do
|
||||
[ -z "$action" ] && continue
|
||||
WINDOW=$(echo "$action" | jq -r '.window // ""')
|
||||
ACTION=$(echo "$action" | jq -r '.action // ""')
|
||||
STATE=$(echo "$action" | jq -r '.state // ""')
|
||||
|
||||
case "$ACTION" in
|
||||
kick) handle_kick "$WINDOW" "$STATE" || true; KICKED=$(( KICKED + 1 )) ;;
|
||||
approve) handle_approve "$WINDOW" || true ;;
|
||||
complete) DONE=$(( DONE + 1 )) ;; # poll-cycle already set state=pending_evaluation; orchestrator handles
|
||||
esac
|
||||
done < <(echo "$ACTIONS" | jq -c '.[]' 2>/dev/null || true)
|
||||
|
||||
RUNNING=$(jq '[.agents[] | select(.state | test("running|stuck|waiting_approval|idle"))] | length' \
|
||||
"$STATE_FILE" 2>/dev/null || echo 0)
|
||||
|
||||
echo "[$(date +%H:%M:%S)] Poll — ${RUNNING} running ${KICKED} kicked ${DONE} recycled"
|
||||
sleep "$POLL_INTERVAL"
|
||||
done
|
||||
122
.claude/skills/orchestrate/scripts/spawn-agent.sh
Executable file
122
.claude/skills/orchestrate/scripts/spawn-agent.sh
Executable file
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env bash
|
||||
# spawn-agent.sh — create tmux window, checkout branch, launch claude, send task
|
||||
#
|
||||
# Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]
|
||||
# SESSION — tmux session name, e.g. autogpt1
|
||||
# WORKTREE_PATH — absolute path to the git worktree
|
||||
# SPARE_BRANCH — spare branch being replaced, e.g. spare/6 (saved for recycle)
|
||||
# NEW_BRANCH — task branch to create, e.g. feat/my-feature
|
||||
# OBJECTIVE — task description sent to the agent
|
||||
# PR_NUMBER — (optional) GitHub PR number for completion verification
|
||||
# STEPS... — (optional) required checkpoint names, e.g. pr-address pr-test
|
||||
#
|
||||
# Stdout: SESSION:WINDOW_INDEX (nothing else — callers rely on this)
|
||||
# Exit non-zero on failure.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ $# -lt 5 ]; then
|
||||
echo "Usage: spawn-agent.sh SESSION WORKTREE_PATH SPARE_BRANCH NEW_BRANCH OBJECTIVE [PR_NUMBER] [STEPS...]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SESSION="$1"
|
||||
WORKTREE_PATH="$2"
|
||||
SPARE_BRANCH="$3"
|
||||
NEW_BRANCH="$4"
|
||||
OBJECTIVE="$5"
|
||||
PR_NUMBER="${6:-}"
|
||||
STEPS=("${@:7}")
|
||||
WORKTREE_NAME=$(basename "$WORKTREE_PATH")
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
# Generate a stable session ID so this agent's Claude session can always be resumed:
|
||||
# claude --resume $SESSION_ID --permission-mode bypassPermissions
|
||||
SESSION_ID=$(uuidgen 2>/dev/null || python3 -c "import uuid; print(uuid.uuid4())")
|
||||
|
||||
# Create (or switch to) the task branch
|
||||
git -C "$WORKTREE_PATH" checkout -b "$NEW_BRANCH" 2>/dev/null \
|
||||
|| git -C "$WORKTREE_PATH" checkout "$NEW_BRANCH"
|
||||
|
||||
# Open a new named tmux window; capture its numeric index
|
||||
WIN_IDX=$(tmux new-window -t "$SESSION" -n "$WORKTREE_NAME" -P -F '#{window_index}')
|
||||
WINDOW="${SESSION}:${WIN_IDX}"
|
||||
|
||||
# Append the initial agent record to the state file so subsequent jq updates find it.
|
||||
# This must happen before the pr_number/steps update below.
|
||||
if [ -f "$STATE_FILE" ]; then
|
||||
NOW=$(date +%s)
|
||||
jq --arg window "$WINDOW" \
|
||||
--arg worktree "$WORKTREE_NAME" \
|
||||
--arg worktree_path "$WORKTREE_PATH" \
|
||||
--arg spare_branch "$SPARE_BRANCH" \
|
||||
--arg branch "$NEW_BRANCH" \
|
||||
--arg objective "$OBJECTIVE" \
|
||||
--arg session_id "$SESSION_ID" \
|
||||
--argjson now "$NOW" \
|
||||
'.agents += [{
|
||||
"window": $window,
|
||||
"worktree": $worktree,
|
||||
"worktree_path": $worktree_path,
|
||||
"spare_branch": $spare_branch,
|
||||
"branch": $branch,
|
||||
"objective": $objective,
|
||||
"session_id": $session_id,
|
||||
"state": "running",
|
||||
"checkpoints": [],
|
||||
"last_output_hash": "",
|
||||
"last_seen_at": $now,
|
||||
"spawned_at": $now,
|
||||
"idle_since": 0,
|
||||
"revision_count": 0,
|
||||
"last_rebriefed_at": 0
|
||||
}]' "$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Store pr_number + steps in state file if provided (enables verify-complete.sh).
|
||||
# The agent record was appended above so the jq select now finds it.
|
||||
if [ -n "$PR_NUMBER" ] && [ -f "$STATE_FILE" ]; then
|
||||
if [ "${#STEPS[@]}" -gt 0 ]; then
|
||||
STEPS_JSON=$(printf '%s\n' "${STEPS[@]}" | jq -R . | jq -s .)
|
||||
else
|
||||
STEPS_JSON='[]'
|
||||
fi
|
||||
jq --arg w "$WINDOW" --arg pr "$PR_NUMBER" --argjson steps "$STEPS_JSON" \
|
||||
'.agents |= map(if .window == $w then . + {pr_number: $pr, steps: $steps, checkpoints: []} else . end)' \
|
||||
"$STATE_FILE" > "${STATE_FILE}.tmp" && mv "${STATE_FILE}.tmp" "$STATE_FILE"
|
||||
fi
|
||||
|
||||
# Launch claude with a stable session ID so it can always be resumed after a crash:
|
||||
# claude --resume SESSION_ID --permission-mode bypassPermissions
|
||||
tmux send-keys -t "$WINDOW" "cd '${WORKTREE_PATH}' && claude --permission-mode bypassPermissions --session-id '${SESSION_ID}'" Enter
|
||||
|
||||
# Wait up to 60s for claude to be fully interactive:
|
||||
# both pane_current_command == 'node' AND the '❯' prompt is visible.
|
||||
PROMPT_FOUND=false
|
||||
for i in $(seq 1 60); do
|
||||
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "")
|
||||
PANE=$(tmux capture-pane -t "$WINDOW" -p 2>/dev/null || echo "")
|
||||
if echo "$PANE" | grep -q "Enter to confirm"; then
|
||||
tmux send-keys -t "$WINDOW" Down Enter
|
||||
sleep 2
|
||||
continue
|
||||
fi
|
||||
if [[ "$CMD" == "node" ]] && echo "$PANE" | grep -q "❯"; then
|
||||
PROMPT_FOUND=true
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
if ! $PROMPT_FOUND; then
|
||||
echo "[spawn-agent] WARNING: timed out waiting for ❯ prompt on $WINDOW — sending objective anyway" >&2
|
||||
fi
|
||||
|
||||
# Send the task. Split text and Enter — if combined, Enter can fire before the string
|
||||
# is fully buffered, leaving the message stuck as "[Pasted text +N lines]" unsent.
|
||||
tmux send-keys -t "$WINDOW" "${OBJECTIVE} Output each completed step as CHECKPOINT:<step-name>. When ALL steps are done, output ORCHESTRATOR:DONE on its own line."
|
||||
sleep 0.3
|
||||
tmux send-keys -t "$WINDOW" Enter
|
||||
|
||||
# Only output the window address — nothing else (callers parse this)
|
||||
echo "$WINDOW"
|
||||
43
.claude/skills/orchestrate/scripts/status.sh
Executable file
43
.claude/skills/orchestrate/scripts/status.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# status.sh — print orchestrator status: state file summary + live tmux pane commands
|
||||
#
|
||||
# Usage: status.sh
|
||||
# Reads: ~/.claude/orchestrator-state.json
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
if [ ! -f "$STATE_FILE" ] || ! jq -e '.' "$STATE_FILE" >/dev/null 2>&1; then
|
||||
echo "No orchestrator state found at $STATE_FILE"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Header: active status, session, thresholds, last poll
|
||||
jq -r '
|
||||
"=== Orchestrator [\(if .active then "RUNNING" else "STOPPED" end)] ===",
|
||||
"Session: \(.tmux_session // "unknown") | Idle threshold: \(.idle_threshold_seconds // 300)s",
|
||||
"Last poll: \(if (.last_poll_at // 0) == 0 then "never" else (.last_poll_at | strftime("%H:%M:%S")) end)",
|
||||
""
|
||||
' "$STATE_FILE"
|
||||
|
||||
# Each agent: state, window, worktree/branch, truncated objective
|
||||
AGENT_COUNT=$(jq '.agents | length' "$STATE_FILE")
|
||||
if [ "$AGENT_COUNT" -eq 0 ]; then
|
||||
echo " (no agents registered)"
|
||||
else
|
||||
jq -r '
|
||||
.agents[] |
|
||||
" [\(.state | ascii_upcase)] \(.window) \(.worktree)/\(.branch)",
|
||||
" \(.objective // "" | .[0:70])"
|
||||
' "$STATE_FILE"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# Live pane_current_command for non-done agents
|
||||
while IFS= read -r WINDOW; do
|
||||
[ -z "$WINDOW" ] && continue
|
||||
CMD=$(tmux display-message -t "$WINDOW" -p '#{pane_current_command}' 2>/dev/null || echo "unreachable")
|
||||
echo " $WINDOW live: $CMD"
|
||||
done < <(jq -r '.agents[] | select(.state != "done") | .window' "$STATE_FILE" 2>/dev/null || true)
|
||||
180
.claude/skills/orchestrate/scripts/verify-complete.sh
Normal file
180
.claude/skills/orchestrate/scripts/verify-complete.sh
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env bash
|
||||
# verify-complete.sh — verify a PR task is truly done before marking the agent done
|
||||
#
|
||||
# Check order matters:
|
||||
# 1. Checkpoints — did the agent do all required steps?
|
||||
# 2. CI complete — no pending (bots post comments AFTER their check runs, must wait)
|
||||
# 3. CI passing — no failures (agent must fix before done)
|
||||
# 4. spawned_at — a new CI run was triggered after agent spawned (proves real work)
|
||||
# 5. Unresolved threads — checked AFTER CI so bot-posted comments are included
|
||||
# 6. CHANGES_REQUESTED — checked AFTER CI so bot reviews are included
|
||||
#
|
||||
# Usage: verify-complete.sh WINDOW
|
||||
# Exit 0 = verified complete; exit 1 = not complete (stderr has reason)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
WINDOW="$1"
|
||||
STATE_FILE="${ORCHESTRATOR_STATE_FILE:-$HOME/.claude/orchestrator-state.json}"
|
||||
|
||||
PR_NUMBER=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .pr_number // ""' "$STATE_FILE" 2>/dev/null)
|
||||
STEPS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .steps // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
CHECKPOINTS=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .checkpoints // [] | .[]' "$STATE_FILE" 2>/dev/null || true)
|
||||
WORKTREE_PATH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .worktree_path // ""' "$STATE_FILE" 2>/dev/null)
|
||||
BRANCH=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .branch // ""' "$STATE_FILE" 2>/dev/null)
|
||||
SPAWNED_AT=$(jq -r --arg w "$WINDOW" '.agents[] | select(.window == $w) | .spawned_at // "0"' "$STATE_FILE" 2>/dev/null || echo "0")
|
||||
|
||||
# No PR number = cannot verify
|
||||
if [ -z "$PR_NUMBER" ]; then
|
||||
echo "NOT COMPLETE: no pr_number in state — set pr_number or mark done manually" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 1: all required steps are checkpointed ---
|
||||
MISSING=""
|
||||
while IFS= read -r step; do
|
||||
[ -z "$step" ] && continue
|
||||
if ! echo "$CHECKPOINTS" | grep -qFx "$step"; then
|
||||
MISSING="$MISSING $step"
|
||||
fi
|
||||
done <<< "$STEPS"
|
||||
|
||||
if [ -n "$MISSING" ]; then
|
||||
echo "NOT COMPLETE: missing checkpoints:$MISSING on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Resolve repo for all GitHub checks below
|
||||
REPO=$(jq -r '.repo // ""' "$STATE_FILE" 2>/dev/null || echo "")
|
||||
if [ -z "$REPO" ] && [ -n "$WORKTREE_PATH" ] && [ -d "$WORKTREE_PATH" ]; then
|
||||
REPO=$(git -C "$WORKTREE_PATH" remote get-url origin 2>/dev/null \
|
||||
| sed 's|.*github\.com[:/]||; s|\.git$||' || echo "")
|
||||
fi
|
||||
|
||||
if [ -z "$REPO" ]; then
|
||||
echo "Warning: cannot resolve repo — skipping CI/thread checks" >&2
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓ (CI/thread checks skipped — no repo)"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
CI_BUCKETS=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket 2>/dev/null || echo "[]")
|
||||
|
||||
# --- Check 2: CI fully complete — no pending checks ---
|
||||
# Pending checks MUST finish before we check threads/reviews:
|
||||
# bots (Seer, Check PR Status, etc.) post comments and CHANGES_REQUESTED AFTER their CI check runs.
|
||||
PENDING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "pending")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$PENDING" -gt 0 ]; then
|
||||
PENDING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "pending") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $PENDING CI checks still pending on PR #$PR_NUMBER ($PENDING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 3: CI passing — no failures ---
|
||||
FAILING=$(echo "$CI_BUCKETS" | jq '[.[] | select(.bucket == "fail")] | length' 2>/dev/null || echo "0")
|
||||
if [ "$FAILING" -gt 0 ]; then
|
||||
FAILING_NAMES=$(gh pr checks "$PR_NUMBER" --repo "$REPO" --json bucket,name 2>/dev/null \
|
||||
| jq -r '[.[] | select(.bucket == "fail") | .name] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
echo "NOT COMPLETE: $FAILING failing CI checks on PR #$PR_NUMBER ($FAILING_NAMES)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 4: a new CI run was triggered AFTER the agent spawned ---
|
||||
if [ -n "$BRANCH" ] && [ "${SPAWNED_AT:-0}" -gt 0 ]; then
|
||||
LATEST_RUN_AT=$(gh run list --repo "$REPO" --branch "$BRANCH" \
|
||||
--json createdAt --limit 1 2>/dev/null | jq -r '.[0].createdAt // ""')
|
||||
if [ -n "$LATEST_RUN_AT" ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_RUN_EPOCH=$(date -d "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_RUN_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_RUN_AT" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$LATEST_RUN_EPOCH" -le "$SPAWNED_AT" ]; then
|
||||
echo "NOT COMPLETE: latest CI run on $BRANCH predates agent spawn — agent may not have pushed yet" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
OWNER=$(echo "$REPO" | cut -d/ -f1)
|
||||
REPONAME=$(echo "$REPO" | cut -d/ -f2)
|
||||
|
||||
# --- Check 5: no unresolved review threads (checked AFTER CI — bots post after their check) ---
|
||||
UNRESOLVED=$(gh api graphql -f query="
|
||||
{ repository(owner: \"${OWNER}\", name: \"${REPONAME}\") {
|
||||
pullRequest(number: ${PR_NUMBER}) {
|
||||
reviewThreads(first: 50) { nodes { isResolved } }
|
||||
}
|
||||
}
|
||||
}
|
||||
" --jq '[.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false)] | length' 2>/dev/null || echo "0")
|
||||
|
||||
if [ "$UNRESOLVED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: $UNRESOLVED unresolved review threads on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# --- Check 6: no CHANGES_REQUESTED (checked AFTER CI — bots post reviews after their check) ---
|
||||
# A CHANGES_REQUESTED review is stale if the latest commit was pushed AFTER the review was submitted.
|
||||
# Stale reviews (pre-dating the fixing commits) should not block verification.
|
||||
#
|
||||
# Fetch commits and latestReviews in a single call and fail closed — if gh fails,
|
||||
# treat that as NOT COMPLETE rather than silently passing.
|
||||
# Use latestReviews (not reviews) so each reviewer's latest state is used — superseded
|
||||
# CHANGES_REQUESTED entries are automatically excluded when the reviewer later approved.
|
||||
# Note: we intentionally use committedDate (not PR updatedAt) because updatedAt changes on any
|
||||
# PR activity (bot comments, label changes) which would create false negatives.
|
||||
PR_REVIEW_METADATA=$(gh pr view "$PR_NUMBER" --repo "$REPO" \
|
||||
--json commits,latestReviews 2>/dev/null) || {
|
||||
echo "NOT COMPLETE: unable to fetch PR review metadata for PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
LATEST_COMMIT_DATE=$(jq -r '.commits[-1].committedDate // ""' <<< "$PR_REVIEW_METADATA")
|
||||
CHANGES_REQUESTED_REVIEWS=$(jq '[.latestReviews[]? | select(.state == "CHANGES_REQUESTED")]' <<< "$PR_REVIEW_METADATA")
|
||||
|
||||
BLOCKING_CHANGES_REQUESTED=0
|
||||
BLOCKING_REQUESTERS=""
|
||||
|
||||
if [ -n "$LATEST_COMMIT_DATE" ] && [ "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length)" -gt 0 ]; then
|
||||
if date --version >/dev/null 2>&1; then
|
||||
LATEST_COMMIT_EPOCH=$(date -d "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
LATEST_COMMIT_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$LATEST_COMMIT_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
|
||||
while IFS= read -r review; do
|
||||
[ -z "$review" ] && continue
|
||||
REVIEW_DATE=$(echo "$review" | jq -r '.submittedAt // ""')
|
||||
REVIEWER=$(echo "$review" | jq -r '.author.login // "unknown"')
|
||||
if [ -z "$REVIEW_DATE" ]; then
|
||||
# No submission date — treat as fresh (conservative: blocks verification)
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
else
|
||||
if date --version >/dev/null 2>&1; then
|
||||
REVIEW_EPOCH=$(date -d "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
else
|
||||
REVIEW_EPOCH=$(TZ=UTC date -j -f "%Y-%m-%dT%H:%M:%SZ" "$REVIEW_DATE" "+%s" 2>/dev/null || echo "0")
|
||||
fi
|
||||
if [ "$REVIEW_EPOCH" -gt "$LATEST_COMMIT_EPOCH" ]; then
|
||||
# Review was submitted AFTER latest commit — still fresh, blocks verification
|
||||
BLOCKING_CHANGES_REQUESTED=$(( BLOCKING_CHANGES_REQUESTED + 1 ))
|
||||
BLOCKING_REQUESTERS="${BLOCKING_REQUESTERS:+$BLOCKING_REQUESTERS, }${REVIEWER}"
|
||||
fi
|
||||
# Review submitted BEFORE latest commit — stale, skip
|
||||
fi
|
||||
done <<< "$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -c '.[]')"
|
||||
else
|
||||
# No commit date or no changes_requested — check raw count as fallback
|
||||
BLOCKING_CHANGES_REQUESTED=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq length 2>/dev/null || echo "0")
|
||||
BLOCKING_REQUESTERS=$(echo "$CHANGES_REQUESTED_REVIEWS" | jq -r '[.[].author.login] | join(", ")' 2>/dev/null || echo "unknown")
|
||||
fi
|
||||
|
||||
if [ "$BLOCKING_CHANGES_REQUESTED" -gt 0 ]; then
|
||||
echo "NOT COMPLETE: CHANGES_REQUESTED (after latest commit) from ${BLOCKING_REQUESTERS} on PR #$PR_NUMBER" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "VERIFIED: PR #$PR_NUMBER — checkpoints ✓, CI complete + green, 0 unresolved threads, no CHANGES_REQUESTED"
|
||||
exit 0
|
||||
@@ -17,6 +17,14 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
@@ -82,10 +90,34 @@ Address comments **one at a time**: fix → commit → push → inline reply →
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
Use a **markdown commit link** so GitHub renders it as a clickable reference. Get the full SHA with `git rev-parse HEAD` after committing:
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in [abc1234](https://github.com/Significant-Gravitas/AutoGPT/commit/FULL_SHA): <description>"` |
|
||||
|
||||
## Codecov coverage
|
||||
|
||||
Codecov patch target is **80%** on changed lines. Checks are **informational** (not blocking) but should be green.
|
||||
|
||||
### Running coverage locally
|
||||
|
||||
**Backend** (from `autogpt_platform/backend/`):
|
||||
```bash
|
||||
poetry run pytest -s -vv --cov=backend --cov-branch --cov-report term-missing
|
||||
```
|
||||
|
||||
**Frontend** (from `autogpt_platform/frontend/`):
|
||||
```bash
|
||||
pnpm vitest run --coverage
|
||||
```
|
||||
|
||||
### When codecov/patch fails
|
||||
|
||||
1. Find uncovered files: `git diff --name-only $(gh pr view --json baseRefName --jq '.baseRefName')...HEAD`
|
||||
2. For each uncovered file — extract inline logic to `helpers.ts`/`helpers.py` and test those (highest ROI). Colocate tests as `*_test.py` (backend) or `__tests__/*.test.ts` (frontend).
|
||||
3. Run coverage locally to verify, commit, push.
|
||||
|
||||
## Format and commit
|
||||
|
||||
@@ -105,7 +137,9 @@ kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
|
||||
|
||||
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
|
||||
@@ -17,6 +17,16 @@ gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoG
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
@@ -34,6 +44,8 @@ gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
@@ -5,13 +5,96 @@ user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
|
||||
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
|
||||
|
||||
## Critical Requirements
|
||||
|
||||
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
|
||||
|
||||
### 1. Screenshots at Every Step
|
||||
- Take a screenshot at EVERY significant test step — not just at the end
|
||||
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
|
||||
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
|
||||
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
|
||||
|
||||
### 2. Screenshots MUST Be Posted to PR
|
||||
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
|
||||
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
|
||||
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
|
||||
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
|
||||
|
||||
### 3. State Verification with Before/After Evidence
|
||||
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
|
||||
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
|
||||
- Screenshot MUST show the relevant UI state change
|
||||
- Compare expected vs actual values explicitly — do not just eyeball it
|
||||
|
||||
### 4. Negative Test Cases Are Mandatory
|
||||
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
|
||||
- Verify error messages are user-friendly and accurate
|
||||
- Verify the system state did NOT change after a rejected operation
|
||||
|
||||
### 5. Test Report Must Include Full Evidence
|
||||
Each test scenario in the report MUST have:
|
||||
- **Steps**: What was done (exact commands or UI actions)
|
||||
- **Expected**: What should happen
|
||||
- **Actual**: What actually happened
|
||||
- **API Evidence**: Before/after API response values for state-changing operations
|
||||
- **Screenshot Evidence**: Before/after screenshots with explanations
|
||||
|
||||
## State Manipulation for Realistic Testing
|
||||
|
||||
When testing features that depend on specific states (rate limits, credits, quotas):
|
||||
|
||||
1. **Use Redis CLI to set counters directly:**
|
||||
```bash
|
||||
# Find the Redis container
|
||||
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
|
||||
# Set a key with expiry
|
||||
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
|
||||
# Example: Set rate limit counter to near-limit
|
||||
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
|
||||
# Example: Check current value
|
||||
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
|
||||
```
|
||||
|
||||
2. **Use API calls to check before/after state:**
|
||||
```bash
|
||||
# BEFORE: Record current state
|
||||
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits BEFORE: $BEFORE"
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER: Record new state and compare
|
||||
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits AFTER: $AFTER"
|
||||
echo "Delta: $(( BEFORE - AFTER ))"
|
||||
```
|
||||
|
||||
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
|
||||
|
||||
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
|
||||
|
||||
5. **Use direct DB queries when needed:**
|
||||
```bash
|
||||
# Query via Supabase's PostgREST or docker exec into the DB
|
||||
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
|
||||
```
|
||||
|
||||
6. **After every API test, verify the state change actually persisted:**
|
||||
```bash
|
||||
# Example: After a credits purchase, verify DB matches API
|
||||
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
|
||||
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
|
||||
@@ -53,14 +136,20 @@ Before testing, understand what changed:
|
||||
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
|
||||
# Read PR description to understand the WHY
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
|
||||
git log --oneline dev..HEAD | head -20
|
||||
git diff dev --stat
|
||||
```
|
||||
|
||||
Read the changed files to understand:
|
||||
1. What feature/fix does this PR implement?
|
||||
2. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
3. What are the key user-facing behaviors to test?
|
||||
Read the PR description (Why / What / How) and changed files to understand:
|
||||
0. **Why** does this PR exist? What problem does it solve?
|
||||
1. **What** feature/fix does this PR implement?
|
||||
2. **How** does it work? What's the approach?
|
||||
3. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
4. What are the key user-facing behaviors to test?
|
||||
|
||||
## Step 2: Write test scenarios
|
||||
|
||||
@@ -75,15 +164,21 @@ Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
## API Tests (if applicable)
|
||||
1. [Endpoint] — [expected behavior]
|
||||
- Before state: [what to check before]
|
||||
- After state: [what to verify changed]
|
||||
|
||||
## UI Tests (if applicable)
|
||||
1. [Page/component] — [interaction to test]
|
||||
- Screenshot before: [what to capture]
|
||||
- Screenshot after: [what to capture]
|
||||
|
||||
## Negative Tests
|
||||
1. [What should NOT happen]
|
||||
## Negative Tests (REQUIRED — at least one per feature)
|
||||
1. [What should NOT happen] — [how to trigger it]
|
||||
- Expected error: [what error message/code]
|
||||
- State unchanged: [what to verify did NOT change]
|
||||
```
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks.
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
@@ -233,7 +328,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
|
||||
### API testing
|
||||
|
||||
Use `curl` with the auth token for backend API tests:
|
||||
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
|
||||
|
||||
```bash
|
||||
# Example: List agents
|
||||
@@ -256,6 +351,27 @@ curl -s -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
|
||||
```
|
||||
|
||||
**State verification pattern (use for EVERY state-changing API call):**
|
||||
```bash
|
||||
# 1. Record BEFORE state
|
||||
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "BEFORE: $BEFORE_STATE"
|
||||
|
||||
# 2. Perform the action
|
||||
ACTION_RESULT=$(curl -s -X POST ... | jq .)
|
||||
echo "ACTION RESULT: $ACTION_RESULT"
|
||||
|
||||
# 3. Record AFTER state
|
||||
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "AFTER: $AFTER_STATE"
|
||||
|
||||
# 4. Log the comparison
|
||||
echo "=== STATE CHANGE VERIFICATION ==="
|
||||
echo "Before: $BEFORE_STATE"
|
||||
echo "After: $AFTER_STATE"
|
||||
echo "Expected change: {describe what should have changed}"
|
||||
```
|
||||
|
||||
### Browser testing with agent-browser
|
||||
|
||||
```bash
|
||||
@@ -348,19 +464,34 @@ agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --time
|
||||
|
||||
## Step 5: Record results and take screenshots
|
||||
|
||||
**Take a screenshot at every significant test step** — before and after interactions, on success, and on failure. Name them sequentially with descriptive names:
|
||||
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
|
||||
|
||||
**Required screenshot pattern for each test scenario:**
|
||||
```bash
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{description}.png
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page.png
|
||||
# $RESULTS_DIR/02-builder-with-block.png
|
||||
# $RESULTS_DIR/03-copilot-response.png
|
||||
# $RESULTS_DIR/04-agent-execution-result.png
|
||||
# $RESULTS_DIR/05-error-state.png
|
||||
# BEFORE the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
|
||||
```
|
||||
|
||||
**Aim for at least one screenshot per test scenario.** More is better — screenshots are the primary evidence that tests were actually run.
|
||||
**Naming convention:**
|
||||
```bash
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page-before.png
|
||||
# $RESULTS_DIR/02-login-page-after.png
|
||||
# $RESULTS_DIR/03-credits-page-before.png
|
||||
# $RESULTS_DIR/04-credits-purchase-after.png
|
||||
# $RESULTS_DIR/05-negative-insufficient-credits.png
|
||||
# $RESULTS_DIR/06-error-state.png
|
||||
```
|
||||
|
||||
**Minimum requirements:**
|
||||
- At least TWO screenshots per test scenario (before + after)
|
||||
- At least ONE screenshot for each negative test case showing the error state
|
||||
- If a test fails, screenshot the failure state AND any error logs visible in the UI
|
||||
|
||||
## Step 6: Show results to user with screenshots
|
||||
|
||||
@@ -384,26 +515,29 @@ Format the output like this:
|
||||
---
|
||||
```
|
||||
|
||||
After showing all screenshots, output a summary table:
|
||||
After showing all screenshots, output a **detailed** summary table:
|
||||
|
||||
| # | Scenario | Result |
|
||||
|---|----------|--------|
|
||||
| 1 | {name} | PASS/FAIL |
|
||||
| 2 | ... | ... |
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
|
||||
| 2 | ... | ... | ... | ... |
|
||||
|
||||
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
|
||||
|
||||
```bash
|
||||
# Build these variables during Step 6 — they are required by Step 7's script
|
||||
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS |
|
||||
| 2 | Builder block addition | PASS |
|
||||
| 3 | Copilot chat | FAIL |"
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
|
||||
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
|
||||
# ... one row per test scenario with actual results
|
||||
```
|
||||
|
||||
@@ -411,6 +545,10 @@ TEST_RESULTS_TABLE="| 1 | Login flow | PASS |
|
||||
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
@@ -418,12 +556,29 @@ SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
|
||||
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
|
||||
|
||||
# Step 1: Create blobs for each screenshot and build tree JSON
|
||||
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
|
||||
shopt -s nullglob
|
||||
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
|
||||
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
|
||||
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
|
||||
exit 1
|
||||
fi
|
||||
TREE_JSON='['
|
||||
FIRST=true
|
||||
for img in $RESULTS_DIR/*.png; do
|
||||
FAILED_UPLOADS=()
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha')
|
||||
BLOB_SHA=""
|
||||
for attempt in 1 2 3; do
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
|
||||
[ -n "$BLOB_SHA" ] && break
|
||||
sleep 1
|
||||
done
|
||||
if [ -z "$BLOB_SHA" ]; then
|
||||
FAILED_UPLOADS+=("$img")
|
||||
continue
|
||||
fi
|
||||
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
|
||||
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
|
||||
done
|
||||
@@ -431,15 +586,27 @@ TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
|
||||
# Resolve parent commit so screenshots are chained, not orphan root commits
|
||||
PARENT_SHA=$(gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" --jq '.object.sha' 2>/dev/null || echo "")
|
||||
if [ -n "$PARENT_SHA" ]; then
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
-f "parents[]=$PARENT_SHA" \
|
||||
--jq '.sha')
|
||||
else
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
fi
|
||||
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
-X PATCH -f sha="$COMMIT_SHA" -F force=true
|
||||
```
|
||||
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
@@ -447,13 +614,25 @@ Then post the comment with **inline images AND explanations for each screenshot*
|
||||
```bash
|
||||
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
|
||||
|
||||
# Build image markdown using SCREENSHOT_EXPLANATIONS and TEST_RESULTS_TABLE from Step 6
|
||||
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
|
||||
|
||||
IMAGE_MARKDOWN=""
|
||||
for img in $RESULTS_DIR/*.png; do
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
|
||||
# Skip images that failed to upload — they will be listed at the end
|
||||
IS_FAILED=false
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
|
||||
done
|
||||
if [ "$IS_FAILED" = true ]; then
|
||||
continue
|
||||
fi
|
||||
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
|
||||
if [ -z "$EXPLANATION" ]; then
|
||||
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
|
||||
exit 1
|
||||
fi
|
||||
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
|
||||
### ${TITLE}
|
||||

|
||||
@@ -463,54 +642,197 @@ done
|
||||
|
||||
# Write comment body to file to avoid shell interpretation issues with special characters
|
||||
COMMENT_FILE=$(mktemp)
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## 🧪 E2E Test Report
|
||||
# If any uploads failed, append a section listing them with instructions
|
||||
FAILED_SECTION=""
|
||||
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
|
||||
FAILED_SECTION="
|
||||
## ⚠️ Failed Screenshot Uploads
|
||||
The following screenshots could not be uploaded via the GitHub API after 3 retries.
|
||||
**To add them:** drag-and-drop or paste these files into a PR comment manually:
|
||||
"
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
- \`$(basename "$failed")\` (local path: \`$failed\`)"
|
||||
done
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
|
||||
| # | Scenario | Result |
|
||||
|---|----------|--------|
|
||||
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
|
||||
fi
|
||||
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## E2E Test Report
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
${TEST_RESULTS_TABLE}
|
||||
|
||||
${IMAGE_MARKDOWN}
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
|
||||
# Verify the posted comment contains inline images — exit 1 if none found
|
||||
# Use separate --paginate + jq pipe: --jq applies per-page, not to the full list
|
||||
LAST_COMMENT=$(gh api "repos/${REPO}/issues/$PR_NUMBER/comments" --paginate 2>/dev/null | jq -r '.[-1].body // ""')
|
||||
if ! echo "$LAST_COMMENT" | grep -q '!\['; then
|
||||
echo "ERROR: Posted comment contains no inline images (![). Bare directory links are not acceptable." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ Inline images verified in posted comment"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL
|
||||
2. Every screenshot rendered inline (not just linked)
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Step 8: Evaluate and post a formal PR review
|
||||
|
||||
After the test comment is posted, evaluate whether the run was thorough enough to make a merge decision, then post a formal GitHub review (approve or request changes). **This step is mandatory — every test run MUST end with a formal review decision.**
|
||||
|
||||
### Evaluation criteria
|
||||
|
||||
Re-read the PR description:
|
||||
```bash
|
||||
gh pr view "$PR_NUMBER" --json body --jq '.body' --repo "$REPO"
|
||||
```
|
||||
|
||||
Score the run against each criterion:
|
||||
|
||||
| Criterion | Pass condition |
|
||||
|-----------|---------------|
|
||||
| **Coverage** | Every feature/change described in the PR has at least one test scenario |
|
||||
| **All scenarios pass** | No FAIL rows in the results table |
|
||||
| **Negative tests** | At least one failure-path test per feature (invalid input, unauthorized, edge case) |
|
||||
| **Before/after evidence** | Every state-changing API call has before/after values logged |
|
||||
| **Screenshots are meaningful** | Screenshots show the actual state change, not just a loading spinner or blank page |
|
||||
| **No regressions** | Existing core flows (login, agent create/run) still work |
|
||||
|
||||
### Decision logic
|
||||
|
||||
```
|
||||
ALL criteria pass → APPROVE
|
||||
Any scenario FAIL or missing PR feature → REQUEST_CHANGES (list gaps)
|
||||
Evidence weak (no before/after, vague shots) → REQUEST_CHANGES (list what's missing)
|
||||
```
|
||||
|
||||
### Post the review
|
||||
|
||||
```bash
|
||||
REVIEW_FILE=$(mktemp)
|
||||
|
||||
# Count results
|
||||
PASS_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "PASS" || true)
|
||||
FAIL_COUNT=$(echo "$TEST_RESULTS_TABLE" | grep -c "FAIL" || true)
|
||||
TOTAL=$(( PASS_COUNT + FAIL_COUNT ))
|
||||
|
||||
# List any coverage gaps found during evaluation (populate this array as you assess)
|
||||
# e.g. COVERAGE_GAPS=("PR claims to add X but no test covers it")
|
||||
COVERAGE_GAPS=()
|
||||
```
|
||||
|
||||
**If APPROVING** — all criteria met, zero failures, full coverage:
|
||||
|
||||
```bash
|
||||
cat > "$REVIEW_FILE" <<REVIEWEOF
|
||||
## E2E Test Evaluation — APPROVED
|
||||
|
||||
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed.
|
||||
|
||||
**Coverage:** All features described in the PR were exercised.
|
||||
|
||||
**Evidence:** Before/after API values logged for all state-changing operations; screenshots show meaningful state transitions.
|
||||
|
||||
**Negative tests:** Failure paths tested for each feature.
|
||||
|
||||
No regressions observed on core flows.
|
||||
REVIEWEOF
|
||||
|
||||
gh pr review "$PR_NUMBER" --repo "$REPO" --approve --body "$(cat "$REVIEW_FILE")"
|
||||
echo "✅ PR approved"
|
||||
```
|
||||
|
||||
**If REQUESTING CHANGES** — any failure, coverage gap, or missing evidence:
|
||||
|
||||
```bash
|
||||
FAIL_LIST=$(echo "$TEST_RESULTS_TABLE" | grep "FAIL" | awk -F'|' '{print "- Scenario" $2 "failed"}' || true)
|
||||
|
||||
cat > "$REVIEW_FILE" <<REVIEWEOF
|
||||
## E2E Test Evaluation — Changes Requested
|
||||
|
||||
**Results:** ${PASS_COUNT}/${TOTAL} scenarios passed, ${FAIL_COUNT} failed.
|
||||
|
||||
### Required before merge
|
||||
|
||||
${FAIL_LIST}
|
||||
$(for gap in "${COVERAGE_GAPS[@]}"; do echo "- $gap"; done)
|
||||
|
||||
Please fix the above and re-run the E2E tests.
|
||||
REVIEWEOF
|
||||
|
||||
gh pr review "$PR_NUMBER" --repo "$REPO" --request-changes --body "$(cat "$REVIEW_FILE")"
|
||||
echo "❌ Changes requested"
|
||||
```
|
||||
|
||||
```bash
|
||||
rm -f "$REVIEW_FILE"
|
||||
```
|
||||
|
||||
**Rules:**
|
||||
- In `--fix` mode, fix all failures before posting the review — the review reflects the final state after fixes
|
||||
- Never approve if any scenario failed, even if it seems like a flake — rerun that scenario first
|
||||
- Never request changes for issues already fixed in this run
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, after finding a bug:
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
1. Identify the root cause in the code
|
||||
2. Fix it in the worktree
|
||||
3. Rebuild the affected service: `cd $PLATFORM_DIR && docker compose up --build -d {service_name}`
|
||||
4. Re-test the scenario
|
||||
5. If fix works, commit and push:
|
||||
### Fix protocol for EVERY issue found (including UX issues):
|
||||
|
||||
1. **Identify** the root cause in the code — read the relevant source files
|
||||
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
|
||||
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
|
||||
4. **Fix** the code in the worktree
|
||||
5. **Rebuild** ONLY the affected service (not the whole stack):
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
|
||||
# e.g., docker compose up --build -d rest_server
|
||||
# e.g., docker compose up --build -d frontend
|
||||
```
|
||||
6. **Wait** for the service to be ready (poll health endpoint)
|
||||
7. **Re-test** the same scenario
|
||||
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
|
||||
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
|
||||
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
|
||||
11. **Commit and push** immediately:
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
git add -A
|
||||
git commit -m "fix: {description of fix}"
|
||||
git push
|
||||
```
|
||||
6. Continue testing remaining scenarios
|
||||
7. After all fixes, run the full test suite again to ensure no regressions
|
||||
12. **Continue** to the next test scenario
|
||||
|
||||
### Fix loop (like pr-address)
|
||||
|
||||
```text
|
||||
test scenario → find bug → fix code → rebuild service → re-test
|
||||
→ repeat until all scenarios pass
|
||||
→ commit + push all fixes
|
||||
→ run full re-test to verify
|
||||
test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
→ fix code → rebuild affected service only → re-test → screenshot fixed state
|
||||
→ verify no regressions → commit + push
|
||||
→ repeat for next scenario
|
||||
→ after ALL scenarios pass, run full re-test to verify everything together
|
||||
```
|
||||
|
||||
**Key differences from non-fix mode:**
|
||||
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
|
||||
- Every fix MUST have a before/after screenshot pair proving it works
|
||||
- Commit after EACH fix, not in a batch at the end
|
||||
- The final re-test must produce a clean set of all-passing screenshots
|
||||
|
||||
## Known issues and workarounds
|
||||
|
||||
### Problem: "Database error finding user" on signup
|
||||
|
||||
195
.claude/skills/setup-repo/SKILL.md
Normal file
195
.claude/skills/setup-repo/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
224
.claude/skills/write-frontend-tests/SKILL.md
Normal file
224
.claude/skills/write-frontend-tests/SKILL.md
Normal file
@@ -0,0 +1,224 @@
|
||||
---
|
||||
name: write-frontend-tests
|
||||
description: "Analyze the current branch diff against dev, plan integration tests for changed frontend pages/components, and write them. TRIGGER when user asks to write frontend tests, add test coverage, or 'write tests for my changes'."
|
||||
user-invocable: true
|
||||
args: "[base branch] — defaults to dev. Optionally pass a specific base branch to diff against."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Write Frontend Tests
|
||||
|
||||
Analyze the current branch's frontend changes, plan integration tests, and write them.
|
||||
|
||||
## References
|
||||
|
||||
Before writing any tests, read the testing rules and conventions:
|
||||
|
||||
- `autogpt_platform/frontend/TESTING.md` — testing strategy, file locations, examples
|
||||
- `autogpt_platform/frontend/src/tests/AGENTS.md` — detailed testing rules, MSW patterns, decision flowchart
|
||||
- `autogpt_platform/frontend/src/tests/integrations/test-utils.tsx` — custom render with providers
|
||||
- `autogpt_platform/frontend/src/tests/integrations/vitest.setup.tsx` — MSW server setup
|
||||
|
||||
## Step 1: Identify changed frontend files
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${ARGUMENTS:-dev}"
|
||||
cd autogpt_platform/frontend
|
||||
|
||||
# Get changed frontend files (excluding generated, config, and test files)
|
||||
git diff "$BASE_BRANCH"...HEAD --name-only -- src/ \
|
||||
| grep -v '__generated__' \
|
||||
| grep -v '__tests__' \
|
||||
| grep -v '\.test\.' \
|
||||
| grep -v '\.stories\.' \
|
||||
| grep -v '\.spec\.'
|
||||
```
|
||||
|
||||
Also read the diff to understand what changed:
|
||||
|
||||
```bash
|
||||
git diff "$BASE_BRANCH"...HEAD --stat -- src/
|
||||
git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
|
||||
```
|
||||
|
||||
## Step 2: Categorize changes and find test targets
|
||||
|
||||
For each changed file, determine:
|
||||
|
||||
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
|
||||
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
|
||||
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
|
||||
|
||||
**Priority order:**
|
||||
1. Pages with new/changed data fetching or user interactions
|
||||
2. Components with complex internal logic (modals, forms, wizards)
|
||||
3. Hooks with non-trivial business logic
|
||||
4. Pure helper functions
|
||||
|
||||
Skip: styling-only changes, type-only changes, config changes.
|
||||
|
||||
## Step 3: Check for existing tests
|
||||
|
||||
For each test target, check if tests already exist:
|
||||
|
||||
```bash
|
||||
# For a page at src/app/(platform)/library/page.tsx
|
||||
ls src/app/\(platform\)/library/__tests__/ 2>/dev/null
|
||||
|
||||
# For a component at src/app/(platform)/library/components/AgentCard/AgentCard.tsx
|
||||
ls src/app/\(platform\)/library/components/AgentCard/__tests__/ 2>/dev/null
|
||||
```
|
||||
|
||||
Note which targets have no tests (need new files) vs which have tests that need updating.
|
||||
|
||||
## Step 4: Identify API endpoints used
|
||||
|
||||
For each test target, find which API hooks are used:
|
||||
|
||||
```bash
|
||||
# Find generated API hook imports in the changed files
|
||||
grep -rn 'from.*__generated__/endpoints' src/app/\(platform\)/library/
|
||||
grep -rn 'use[A-Z].*V[12]' src/app/\(platform\)/library/
|
||||
```
|
||||
|
||||
For each API hook found, locate the corresponding MSW handler:
|
||||
|
||||
```bash
|
||||
# If the page uses useGetV2ListLibraryAgents, find its MSW handlers
|
||||
grep -rn 'getGetV2ListLibraryAgents.*Handler' src/app/api/__generated__/endpoints/library/library.msw.ts
|
||||
```
|
||||
|
||||
List every MSW handler you will need (200 for happy path, 4xx for error paths).
|
||||
|
||||
## Step 5: Write the test plan
|
||||
|
||||
Before writing code, output a plan as a numbered list:
|
||||
|
||||
```
|
||||
Test plan for [branch name]:
|
||||
|
||||
1. src/app/(platform)/library/__tests__/main.test.tsx (NEW)
|
||||
- Renders page with agent list (MSW 200)
|
||||
- Shows loading state
|
||||
- Shows error state (MSW 422)
|
||||
- Handles empty agent list
|
||||
|
||||
2. src/app/(platform)/library/__tests__/search.test.tsx (NEW)
|
||||
- Filters agents by search query
|
||||
- Shows no results message
|
||||
- Clears search
|
||||
|
||||
3. src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx (UPDATE)
|
||||
- Add test for new "duplicate" action
|
||||
```
|
||||
|
||||
Present this plan to the user. Wait for confirmation before proceeding. If the user has feedback, adjust the plan.
|
||||
|
||||
## Step 6: Write the tests
|
||||
|
||||
For each test file in the plan, follow these conventions:
|
||||
|
||||
### File structure
|
||||
|
||||
```tsx
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
// Import MSW handlers for endpoints the page uses
|
||||
import {
|
||||
getGetV2ListLibraryAgentsMockHandler200,
|
||||
getGetV2ListLibraryAgentsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/library/library.msw";
|
||||
// Import the component under test
|
||||
import LibraryPage from "../page";
|
||||
|
||||
describe("LibraryPage", () => {
|
||||
test("renders agent list from API", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler200());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/my agents/i)).toBeDefined();
|
||||
});
|
||||
|
||||
test("shows error state on API failure", async () => {
|
||||
server.use(getGetV2ListLibraryAgentsMockHandler422());
|
||||
|
||||
render(<LibraryPage />);
|
||||
|
||||
expect(await screen.findByText(/error/i)).toBeDefined();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Rules
|
||||
|
||||
- Use `render()` from `@/tests/integrations/test-utils` (NOT from `@testing-library/react` directly)
|
||||
- Use `server.use()` to set up MSW handlers BEFORE rendering
|
||||
- Use `findBy*` (async) for elements that appear after data fetching — NOT `getBy*`
|
||||
- Use `getBy*` only for elements that are immediately present in the DOM
|
||||
- Use `screen` queries — do NOT destructure from `render()`
|
||||
- Use `waitFor` when asserting side effects or state changes after interactions
|
||||
- Import `fireEvent` or `userEvent` from the test-utils for interactions
|
||||
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
|
||||
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
|
||||
- Keep tests focused: one behavior per test
|
||||
- Use descriptive test names that read like sentences
|
||||
|
||||
### Test location
|
||||
|
||||
```
|
||||
# For pages: __tests__/ next to page.tsx
|
||||
src/app/(platform)/library/__tests__/main.test.tsx
|
||||
|
||||
# For complex standalone components: __tests__/ inside component folder
|
||||
src/app/(platform)/library/components/AgentCard/__tests__/AgentCard.test.tsx
|
||||
|
||||
# For pure helpers: co-located .test.ts
|
||||
src/app/(platform)/library/helpers.test.ts
|
||||
```
|
||||
|
||||
### Custom MSW overrides
|
||||
|
||||
When the auto-generated faker data is not enough, override with specific data:
|
||||
|
||||
```tsx
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
server.use(
|
||||
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
|
||||
return HttpResponse.json({
|
||||
agents: [
|
||||
{ id: "1", name: "Test Agent", description: "A test agent" },
|
||||
],
|
||||
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
|
||||
});
|
||||
}),
|
||||
);
|
||||
```
|
||||
|
||||
Use the proxy URL pattern: `http://localhost:3000/api/proxy/api/v{version}/{path}` — this matches the MSW base URL configured in `orval.config.ts`.
|
||||
|
||||
## Step 7: Run and verify
|
||||
|
||||
After writing all tests:
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/frontend
|
||||
pnpm test:unit --reporter=verbose
|
||||
```
|
||||
|
||||
If tests fail:
|
||||
1. Read the error output carefully
|
||||
2. Fix the test (not the source code, unless there is a genuine bug)
|
||||
3. Re-run until all pass
|
||||
|
||||
Then run the full checks:
|
||||
|
||||
```bash
|
||||
pnpm format
|
||||
pnpm lint
|
||||
pnpm types
|
||||
```
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,8 +1,12 @@
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
78
.github/workflows/classic-autogpt-ci.yml
vendored
78
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,19 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -19,47 +27,22 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/original_autogpt
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -71,41 +54,23 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -116,12 +81,13 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -135,11 +101,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
flags: autogpt-agent
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/original_autogpt/logs/
|
||||
path: classic/logs/
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
--entrypoint poetry ${{ env.IMAGE_NAME }} run \
|
||||
pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
|
||||
--numprocesses=4 --durations=10 \
|
||||
tests/unit tests/integration 2>&1 | tee test_output.txt
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration 2>&1 | tee test_output.txt
|
||||
|
||||
test_failure=${PIPESTATUS[0]}
|
||||
|
||||
|
||||
44
.github/workflows/classic-autogpts-ci.yml
vendored
44
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -10,10 +10,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -21,10 +20,9 @@ on:
|
||||
- '.github/workflows/classic-autogpts-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -35,13 +33,9 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,22 +49,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
|
||||
256
.github/workflows/classic-benchmark-ci.yml
vendored
256
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,18 +1,24 @@
|
||||
name: Classic - AGBenchmark CI
|
||||
name: Classic - Direct Benchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('benchmark-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -23,95 +29,16 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/benchmark
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
|
||||
self-test-with-agent:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -124,53 +51,120 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run basic benchmark tests
|
||||
run: |
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
189
.github/workflows/classic-forge-ci.yml
vendored
189
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,15 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -21,131 +23,60 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/forge
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: poetry run playwright install chromium
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
forge/forge forge/tests
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -159,85 +90,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
flags: forge
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/forge/logs/
|
||||
path: classic/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -16,7 +18,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -27,44 +31,13 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -81,42 +54,31 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -133,19 +95,16 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
20
.github/workflows/platform-backend-ci.yml
vendored
20
.github/workflows/platform-backend-ci.yml
vendored
@@ -269,12 +269,14 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- name: Run pytest
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
poetry run pytest -s -vv \
|
||||
--cov=backend --cov-branch --cov-report term-missing --cov-report xml
|
||||
fi
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
@@ -287,11 +289,13 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-backend
|
||||
files: ./autogpt_platform/backend/coverage.xml
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
8
.github/workflows/platform-frontend-ci.yml
vendored
8
.github/workflows/platform-frontend-ci.yml
vendored
@@ -148,3 +148,11 @@ jobs:
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend
|
||||
files: ./autogpt_platform/frontend/coverage/cobertura-coverage.xml
|
||||
|
||||
25
.github/workflows/platform-fullstack-ci.yml
vendored
25
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -179,21 +179,30 @@ jobs:
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
export NEXT_PUBLIC_SOURCEMAPS NEXT_PUBLIC_PW_TEST
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Ensure NEXT_PUBLIC_SOURCEMAPS is in resolved compose
|
||||
# (docker compose config on some versions drops this arg)
|
||||
if ! grep -q "NEXT_PUBLIC_SOURCEMAPS" docker-compose.resolved.yml; then
|
||||
echo "Injecting NEXT_PUBLIC_SOURCEMAPS into resolved compose (docker compose config dropped it)"
|
||||
sed -i '/NEXT_PUBLIC_PW_TEST/a\ NEXT_PUBLIC_SOURCEMAPS: "true"' docker-compose.resolved.yml
|
||||
fi
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}-sourcemaps" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
NEXT_PUBLIC_SOURCEMAPS: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
@@ -279,6 +288,11 @@ jobs:
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Copy source maps from Docker for E2E coverage
|
||||
run: |
|
||||
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
|
||||
docker cp "$FRONTEND_CONTAINER":/app/.next/static .next-static-coverage
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
@@ -289,6 +303,15 @@ jobs:
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload E2E coverage to Codecov
|
||||
if: ${{ !cancelled() }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: platform-frontend-e2e
|
||||
files: ./autogpt_platform/frontend/coverage/e2e/cobertura-coverage.xml
|
||||
disable_search: true
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -16,6 +17,7 @@ log-ingestion.txt
|
||||
/logs
|
||||
*.log
|
||||
*.mp3
|
||||
!autogpt_platform/frontend/public/notification.mp3
|
||||
mem.sqlite3
|
||||
venvAutoGPT
|
||||
|
||||
@@ -159,6 +161,10 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -177,9 +183,13 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
**/.claude/settings.local.json
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
|
||||
36
.gitleaks.toml
Normal file
36
.gitleaks.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
title = "AutoGPT Gitleaks Config"
|
||||
|
||||
[extend]
|
||||
useDefault = true
|
||||
|
||||
[allowlist]
|
||||
description = "Global allowlist"
|
||||
paths = [
|
||||
# Template/example env files (no real secrets)
|
||||
'''\.env\.(default|example|template)$''',
|
||||
# Lock files
|
||||
'''pnpm-lock\.yaml$''',
|
||||
'''poetry\.lock$''',
|
||||
# Secrets baseline
|
||||
'''\.secrets\.baseline$''',
|
||||
# Build artifacts and caches (should not be committed)
|
||||
'''__pycache__/''',
|
||||
'''classic/frontend/build/''',
|
||||
# Docker dev setup (local dev JWTs/keys only)
|
||||
'''autogpt_platform/db/docker/''',
|
||||
# Load test configs (dev JWTs)
|
||||
'''load-tests/configs/''',
|
||||
# Test files with fake/fixture keys (_test.py, test_*.py, conftest.py)
|
||||
'''(_test|test_.*|conftest)\.py$''',
|
||||
# Documentation (only contains placeholder keys in curl/API examples)
|
||||
'''docs/.*\.md$''',
|
||||
# Firebase config (public API keys by design)
|
||||
'''google-services\.json$''',
|
||||
'''classic/frontend/(lib|web)/''',
|
||||
]
|
||||
# CI test-only encryption key (marked DO NOT USE IN PRODUCTION)
|
||||
regexes = [
|
||||
'''dvziYgz0KSK8FENhju0ZYi8''',
|
||||
# LLM model name enum values falsely flagged as API keys
|
||||
'''Llama-\d.*Instruct''',
|
||||
]
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -23,9 +23,15 @@ repos:
|
||||
- id: detect-secrets
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
args: ["--baseline", ".secrets.baseline"]
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
exclude: (pnpm-lock\.yaml|\.env\.(default|example|template))$
|
||||
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
name: Detect secrets (gitleaks)
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, all dependencies need to be up-to-date.
|
||||
@@ -84,51 +90,16 @@ repos:
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
fi | grep -qE "^classic/poetry\.lock$" || exit 0;
|
||||
poetry -C classic install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
@@ -223,26 +194,10 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -256,26 +211,13 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -311,29 +253,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -360,26 +283,9 @@ repos:
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
# alias: pytest-classic-autogpt
|
||||
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# # include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Forge (excl. slow tests)
|
||||
# alias: pytest-classic-forge
|
||||
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Benchmark
|
||||
# alias: pytest-classic-benchmark
|
||||
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# name: Run tests - Classic (excl. slow tests)
|
||||
# alias: pytest-classic
|
||||
# entry: bash -c 'cd classic && poetry run pytest -m "not slow"'
|
||||
# files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
|
||||
467
.secrets.baseline
Normal file
467
.secrets.baseline
Normal file
@@ -0,0 +1,467 @@
|
||||
{
|
||||
"version": "1.5.0",
|
||||
"plugins_used": [
|
||||
{
|
||||
"name": "ArtifactoryDetector"
|
||||
},
|
||||
{
|
||||
"name": "AWSKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "AzureStorageKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "Base64HighEntropyString",
|
||||
"limit": 4.5
|
||||
},
|
||||
{
|
||||
"name": "BasicAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "CloudantDetector"
|
||||
},
|
||||
{
|
||||
"name": "DiscordBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitHubTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "GitLabTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "HexHighEntropyString",
|
||||
"limit": 3.0
|
||||
},
|
||||
{
|
||||
"name": "IbmCloudIamDetector"
|
||||
},
|
||||
{
|
||||
"name": "IbmCosHmacDetector"
|
||||
},
|
||||
{
|
||||
"name": "IPPublicDetector"
|
||||
},
|
||||
{
|
||||
"name": "JwtTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "KeywordDetector",
|
||||
"keyword_exclude": ""
|
||||
},
|
||||
{
|
||||
"name": "MailchimpDetector"
|
||||
},
|
||||
{
|
||||
"name": "NpmDetector"
|
||||
},
|
||||
{
|
||||
"name": "OpenAIDetector"
|
||||
},
|
||||
{
|
||||
"name": "PrivateKeyDetector"
|
||||
},
|
||||
{
|
||||
"name": "PypiTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "SendGridDetector"
|
||||
},
|
||||
{
|
||||
"name": "SlackDetector"
|
||||
},
|
||||
{
|
||||
"name": "SoftlayerDetector"
|
||||
},
|
||||
{
|
||||
"name": "SquareOAuthDetector"
|
||||
},
|
||||
{
|
||||
"name": "StripeDetector"
|
||||
},
|
||||
{
|
||||
"name": "TelegramBotTokenDetector"
|
||||
},
|
||||
{
|
||||
"name": "TwilioKeyDetector"
|
||||
}
|
||||
],
|
||||
"filters_used": [
|
||||
{
|
||||
"path": "detect_secrets.filters.allowlist.is_line_allowlisted"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
|
||||
"min_level": 2
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_indirect_reference"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_likely_id_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_lock_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_potential_uuid"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_sequential_string"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_swagger_file"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.heuristic.is_templated_secret"
|
||||
},
|
||||
{
|
||||
"path": "detect_secrets.filters.regex.should_exclude_file",
|
||||
"pattern": [
|
||||
"\\.env$",
|
||||
"pnpm-lock\\.yaml$",
|
||||
"\\.env\\.(default|example|template)$",
|
||||
"__pycache__",
|
||||
"_test\\.py$",
|
||||
"test_.*\\.py$",
|
||||
"conftest\\.py$",
|
||||
"poetry\\.lock$",
|
||||
"node_modules"
|
||||
]
|
||||
}
|
||||
],
|
||||
"results": {
|
||||
"autogpt_platform/backend/backend/api/external/v1/integrations.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/api/external/v1/integrations.py",
|
||||
"hashed_secret": "665b1e3851eefefa3fb878654292f16597d25155",
|
||||
"is_verified": false,
|
||||
"line_number": 289
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/airtable/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/airtable/_config.py",
|
||||
"hashed_secret": "57e168b03afb7c1ee3cdc4ee3db2fe1cc6e0df26",
|
||||
"is_verified": false,
|
||||
"line_number": 29
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/dataforseo/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/dataforseo/_config.py",
|
||||
"hashed_secret": "32ce93887331fa5d192f2876ea15ec000c7d58b8",
|
||||
"is_verified": false,
|
||||
"line_number": 12
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/checks.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/checks.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 108
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/ci.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/ci.py",
|
||||
"hashed_secret": "90bd1b48e958257948487b90bee080ba5ed00caa",
|
||||
"is_verified": false,
|
||||
"line_number": 123
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "f96896dafced7387dcd22343b8ea29d3d2c65663",
|
||||
"is_verified": false,
|
||||
"line_number": 42
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b80a94d5e70bedf4f5f89d2f5a5255cc9492d12e",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "75b17e517fe1b3136394f6bec80c4f892da75e42",
|
||||
"is_verified": false,
|
||||
"line_number": 344
|
||||
},
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/example_payloads/pull_request.synchronize.json",
|
||||
"hashed_secret": "b0bfb5e4e2394e7f8906e5ed1dffd88b2bc89dd5",
|
||||
"is_verified": false,
|
||||
"line_number": 534
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/github/statuses.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/github/statuses.py",
|
||||
"hashed_secret": "8ac6f92737d8586790519c5d7bfb4d2eb172c238",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/docs.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/docs.py",
|
||||
"hashed_secret": "c95da0c6696342c867ef0c8258d2f74d20fd94d4",
|
||||
"is_verified": false,
|
||||
"line_number": 203
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/google/sheets.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/google/sheets.py",
|
||||
"hashed_secret": "bd5a04fa3667e693edc13239b6d310c5c7a8564b",
|
||||
"is_verified": false,
|
||||
"line_number": 57
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/linear/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/linear/_config.py",
|
||||
"hashed_secret": "b37f020f42d6d613b6ce30103e4d408c4499b3bb",
|
||||
"is_verified": false,
|
||||
"line_number": 53
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/medium.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/medium.py",
|
||||
"hashed_secret": "ff998abc1ce6d8f01a675fa197368e44c8916e9c",
|
||||
"is_verified": false,
|
||||
"line_number": 131
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/replicate/replicate_block.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/replicate/replicate_block.py",
|
||||
"hashed_secret": "8bbdd6f26368f58ea4011d13d7f763cb662e66f0",
|
||||
"is_verified": false,
|
||||
"line_number": 55
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/slant3d/webhook.py": [
|
||||
{
|
||||
"type": "Hex High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/slant3d/webhook.py",
|
||||
"hashed_secret": "36263c76947443b2f6e6b78153967ac4a7da99f9",
|
||||
"is_verified": false,
|
||||
"line_number": 100
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/talking_head.py": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/talking_head.py",
|
||||
"hashed_secret": "44ce2d66222529eea4a32932823466fc0601c799",
|
||||
"is_verified": false,
|
||||
"line_number": 113
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/blocks/wordpress/_config.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/blocks/wordpress/_config.py",
|
||||
"hashed_secret": "e62679512436161b78e8a8d68c8829c2a1031ccb",
|
||||
"is_verified": false,
|
||||
"line_number": 17
|
||||
}
|
||||
],
|
||||
"autogpt_platform/backend/backend/util/cache.py": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/backend/backend/util/cache.py",
|
||||
"hashed_secret": "37f0c918c3fa47ca4a70e42037f9f123fdfbc75b",
|
||||
"is_verified": false,
|
||||
"line_number": 449
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/helpers.ts",
|
||||
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/en.json",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/dictionaries/es.json",
|
||||
"hashed_secret": "5a6d1c612954979ea99ee33dbb2d231b00f6ac0a",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 6
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/AgentInputsReadOnly/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 8
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 5
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/ModalRunSection/helpers.ts",
|
||||
"hashed_secret": "f72cbb45464d487064610c5411c576ca4019d380",
|
||||
"is_verified": false,
|
||||
"line_number": 7
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "cf678cab87dc1f7d1b95b964f15375e088461679",
|
||||
"is_verified": false,
|
||||
"line_number": 192
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/app/(platform)/profile/(user)/integrations/page.tsx",
|
||||
"hashed_secret": "86275db852204937bbdbdebe5fabe8536e030ab6",
|
||||
"is_verified": false,
|
||||
"line_number": 193
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "47acd2028cf81b5da88ddeedb2aea4eca4b71fbd",
|
||||
"is_verified": false,
|
||||
"line_number": 102
|
||||
},
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/components/contextual/CredentialsInput/helpers.ts",
|
||||
"hashed_secret": "8be3c943b1609fffbfc51aad666d0a04adf83c9d",
|
||||
"is_verified": false,
|
||||
"line_number": 103
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts": [
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9c486c92f1a7420e1045c7ad963fbb7ba3621025",
|
||||
"is_verified": false,
|
||||
"line_number": 73
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "9277508c7a6effc8fb59163efbfada189e35425c",
|
||||
"is_verified": false,
|
||||
"line_number": 75
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "8dc7e2cb1d0935897d541bf5facab389b8a50340",
|
||||
"is_verified": false,
|
||||
"line_number": 77
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "79a26ad48775944299be6aaf9fb1d5302c1ed75b",
|
||||
"is_verified": false,
|
||||
"line_number": 79
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a3b62b44500a1612e48d4cab8294df81561b3b1a",
|
||||
"is_verified": false,
|
||||
"line_number": 81
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "a58979bd0b21ef4f50417d001008e60dd7a85c64",
|
||||
"is_verified": false,
|
||||
"line_number": 83
|
||||
},
|
||||
{
|
||||
"type": "Base64 High Entropy String",
|
||||
"filename": "autogpt_platform/frontend/src/lib/autogpt-server-api/utils.ts",
|
||||
"hashed_secret": "6cb6e075f8e8c7c850f9d128d6608e5dbe209a79",
|
||||
"is_verified": false,
|
||||
"line_number": 85
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/lib/constants.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/lib/constants.ts",
|
||||
"hashed_secret": "27b924db06a28cc755fb07c54f0fddc30659fe4d",
|
||||
"is_verified": false,
|
||||
"line_number": 10
|
||||
}
|
||||
],
|
||||
"autogpt_platform/frontend/src/tests/credentials/index.ts": [
|
||||
{
|
||||
"type": "Secret Keyword",
|
||||
"filename": "autogpt_platform/frontend/src/tests/credentials/index.ts",
|
||||
"hashed_secret": "c18006fc138809314751cd1991f1e0b820fabd37",
|
||||
"is_verified": false,
|
||||
"line_number": 4
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-04-02T13:10:54Z"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
This guide provides context for coding agents when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
|
||||
@@ -30,7 +30,7 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
5. **Testing**: Integration tests (Vitest + RTL + MSW) are the default (~90%, page-level). Playwright for E2E critical flows. Storybook for design system components. See `autogpt_platform/frontend/TESTING.md`
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
@@ -47,7 +47,9 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
- Frontend integration tests: `pnpm test:unit` (Vitest + RTL + MSW, primary testing approach).
|
||||
- Frontend E2E tests: `pnpm test` or `pnpm test-ui` for Playwright tests.
|
||||
- See `autogpt_platform/frontend/TESTING.md` for the full testing strategy.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
120
autogpt_platform/AGENTS.md
Normal file
120
autogpt_platform/AGENTS.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# AutoGPT Platform
|
||||
|
||||
This file provides guidance to coding agents when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@@ -1,118 +1 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@AGENTS.md
|
||||
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
version = "7.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.0"
|
||||
version = "0.15.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -178,6 +178,7 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
227
autogpt_platform/backend/AGENTS.md
Normal file
227
autogpt_platform/backend/AGENTS.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Backend
|
||||
|
||||
This file provides guidance to coding agents when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -1,226 +1 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@AGENTS.md
|
||||
|
||||
@@ -18,14 +18,22 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
@@ -91,18 +99,6 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
@@ -473,12 +469,12 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
@@ -486,28 +482,19 @@ async def list_credentials(
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
@@ -515,16 +502,7 @@ async def list_credentials_by_provider(
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -597,11 +575,11 @@ async def create_credential(
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
@@ -639,15 +617,23 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
|
||||
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
return ChatSession.new(user_id, dry_run=False)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
SubscriptionTier,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
get_user_tier,
|
||||
reset_user_usage,
|
||||
set_user_tier,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["copilot", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class UserTierResponse(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
class SetUserTierRequest(BaseModel):
|
||||
user_id: str
|
||||
tier: SubscriptionTier
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
user_id: Optional[str], email: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Resolve a user_id and email from the provided parameters.
|
||||
|
||||
Returns (user_id, email). Accepts either user_id or email; at least one
|
||||
must be provided. When both are provided, ``email`` takes precedence.
|
||||
"""
|
||||
if email:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No user found with the provided email."
|
||||
)
|
||||
return user.id, email
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either user_id or email query parameter is required.",
|
||||
)
|
||||
|
||||
# We have a user_id; try to look up their email for display purposes.
|
||||
# This is non-critical -- a failure should not block the response.
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
return user_id, resolved_email
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Get User Rate Limit",
|
||||
)
|
||||
async def get_user_rate_limit(
|
||||
user_id: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Get a user's current usage and effective rate limits. Admin-only.
|
||||
|
||||
Accepts either ``user_id`` or ``email`` as a query parameter.
|
||||
When ``email`` is provided the user is looked up by email first.
|
||||
"""
|
||||
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/reset",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Reset User Rate Limit Usage",
|
||||
)
|
||||
async def reset_user_rate_limit(
|
||||
user_id: str = Body(embed=True),
|
||||
reset_weekly: bool = Body(False, embed=True),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
|
||||
logger.info(
|
||||
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
|
||||
admin_user_id,
|
||||
user_id,
|
||||
reset_weekly,
|
||||
)
|
||||
|
||||
try:
|
||||
await reset_user_usage(user_id, reset_weekly=reset_weekly)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Get User Rate Limit Tier",
|
||||
)
|
||||
async def get_user_rate_limit_tier(
|
||||
user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Get a user's current rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
|
||||
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
|
||||
|
||||
tier = await get_user_tier(user_id)
|
||||
return UserTierResponse(user_id=user_id, tier=tier)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/tier",
|
||||
response_model=UserTierResponse,
|
||||
summary="Set User Rate Limit Tier",
|
||||
)
|
||||
async def set_user_rate_limit_tier(
|
||||
request: SetUserTierRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserTierResponse:
|
||||
"""Set a user's rate-limit tier. Admin-only.
|
||||
|
||||
Returns 404 if the user does not exist in the database.
|
||||
"""
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(request.user_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to resolve email for user %s",
|
||||
request.user_id,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved_email = None
|
||||
|
||||
if resolved_email is None:
|
||||
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
|
||||
|
||||
old_tier = await get_user_tier(request.user_id)
|
||||
logger.info(
|
||||
"Admin %s changing tier for user %s (%s): %s -> %s",
|
||||
admin_user_id,
|
||||
request.user_id,
|
||||
resolved_email,
|
||||
old_tier.value,
|
||||
request.tier.value,
|
||||
)
|
||||
try:
|
||||
await set_user_tier(request.user_id, request.tier)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to set user tier")
|
||||
raise HTTPException(status_code=500, detail="Failed to set tier") from e
|
||||
|
||||
return UserTierResponse(user_id=request.user_id, tier=request.tier)
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit/search_users",
|
||||
response_model=list[UserSearchResult],
|
||||
summary="Search Users by Name or Email",
|
||||
)
|
||||
async def admin_search_users(
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> list[UserSearchResult]:
|
||||
"""Search users by partial email or name. Admin-only.
|
||||
|
||||
Queries the User table directly — returns results even for users
|
||||
without credit transaction history.
|
||||
"""
|
||||
if len(query.strip()) < 3:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Search query must be at least 3 characters.",
|
||||
)
|
||||
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
|
||||
results = await search_users(query, limit=max(1, min(limit, 50)))
|
||||
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
|
||||
@@ -0,0 +1,566 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(rate_limit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
|
||||
|
||||
_TARGET_EMAIL = "target@example.com"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_usage_status(
|
||||
daily_used: int = 500_000, weekly_used: int = 3_000_000
|
||||
) -> CoPilotUsageStatus:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _patch_rate_limit_deps(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
daily_used: int = 500_000,
|
||||
weekly_used: int = 3_000_000,
|
||||
):
|
||||
"""Patch the common rate-limit + user-lookup dependencies."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting rate limit and usage for a user."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"get_rate_limit",
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test looking up rate limits via email instead of user_id."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test that looking up a non-existent email returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_rate_limit_no_params() -> None:
|
||||
"""Test that omitting both user_id and email returns 400."""
|
||||
response = client.get("/admin/rate_limit")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting only daily usage (default behaviour)."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_only",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting both daily and weekly usage."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id, "reset_weekly": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
assert data["tier"] == "FREE"
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_and_weekly",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that Redis failure on reset returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Redis connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_rate_limit_email_lookup_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that failing to resolve a user email degrades gracefully."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection lost"),
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] is None
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that rate limit admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting a user's rate-limit tier."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "PRO"
|
||||
|
||||
|
||||
def test_get_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that getting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test setting a user's rate-limit tier (upgrade)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "ENTERPRISE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
|
||||
|
||||
|
||||
def test_set_user_tier_downgrade(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test downgrading a user's tier from PRO to FREE."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.PRO,
|
||||
)
|
||||
mock_set = mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "FREE"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["tier"] == "FREE"
|
||||
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an invalid tier returns 422."""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "invalid"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_set_user_tier_invalid_tier_uppercase(
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
|
||||
|
||||
Regression: ensures Pydantic enum validation rejects values that are not
|
||||
members of SubscriptionTier, even when they look like valid enum names.
|
||||
"""
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "INVALID"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
body = response.json()
|
||||
assert "detail" in body
|
||||
|
||||
|
||||
def test_set_user_tier_email_lookup_failure_returns_404(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that email lookup failure returns 404 (user unverifiable)."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection failed"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_user_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that setting tier for a non-existent user returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_set_user_tier_db_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that DB failure on set tier returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SubscriptionTier.FREE,
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.set_user_tier",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": target_user_id, "tier": "PRO"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that tier admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/tier",
|
||||
json={"user_id": "test", "tier": "PRO"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ─── search_users endpoint ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_search_users_returns_matching_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Partial search should return all matching users from the User table."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[
|
||||
("user-1", "zamil.majdy@gmail.com"),
|
||||
("user-2", "zamil.majdy@agpt.co"),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
|
||||
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert len(results) == 2
|
||||
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
|
||||
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
|
||||
|
||||
|
||||
def test_search_users_empty_results(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Search with no matches returns empty list."""
|
||||
mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_search_users_short_query_rejected(
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Query shorter than 3 characters should return 400."""
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_search_users_negative_limit_clamped(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
admin_user_id: str,
|
||||
) -> None:
|
||||
"""Negative limit should be clamped to 1, not passed through."""
|
||||
mock_search = mocker.patch(
|
||||
_MOCK_MODULE + ".search_users",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_search.assert_awaited_once_with("test", limit=1)
|
||||
|
||||
|
||||
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that the search_users endpoint requires admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
|
||||
assert response.status_code == 403
|
||||
@@ -7,6 +7,8 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
@@ -132,3 +134,40 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/{store_listing_version_id}/preview",
|
||||
summary="Admin Preview Submission Listing",
|
||||
)
|
||||
async def admin_preview_submission(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Preview a marketplace submission as it would appear on the listing page.
|
||||
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
|
||||
submissions before approving.
|
||||
"""
|
||||
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/add-to-library",
|
||||
summary="Admin Add Pending Agent to Library",
|
||||
status_code=201,
|
||||
)
|
||||
async def admin_add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add a pending marketplace agent to the admin's library for review.
|
||||
Uses admin-level access to bypass marketplace APPROVED-only checks.
|
||||
|
||||
The builder can load the graph because get_graph() checks library
|
||||
membership as a fallback: "you added it, you keep it."
|
||||
"""
|
||||
return await library_db.add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Tests for admin store routes and the bypass logic they depend on.
|
||||
|
||||
Tests are organized by what they protect:
|
||||
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
|
||||
- SECRT-2167 security: admin endpoints reject non-admin users
|
||||
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
|
||||
and add-to-library uses get_graph_as_admin (not get_graph)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .store_admin_routes import router as store_admin_router
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
SLV_ID = "test-store-listing-version-id"
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""get_graph_as_admin must return a graph even when the admin doesn't own
|
||||
it and it's not APPROVED in the marketplace."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""get_graph_as_admin with for_export=True must call get_sub_graphs
|
||||
and pass sub_graphs to GraphModel.from_db."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
|
||||
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(store_admin_router)
|
||||
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(
|
||||
request: fastapi.Request, exc: NotFoundError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all route tests in this module."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_preview_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the preview endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the add-to-library endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_preview_nonexistent_submission(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Preview of a nonexistent submission returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.store_admin_routes.store_db"
|
||||
".get_store_agent_details_as_admin",
|
||||
side_effect=NotFoundError("not found"),
|
||||
)
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
|
||||
"""get_store_agent_details_as_admin must query StoreListingVersion
|
||||
directly (not the APPROVED-only StoreAgent view). This is THE test that
|
||||
prevents the bypass from being accidentally reverted."""
|
||||
from backend.api.features.store.db import get_store_agent_details_as_admin
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.id = SLV_ID
|
||||
mock_slv.name = "Test Agent"
|
||||
mock_slv.subHeading = "Short desc"
|
||||
mock_slv.description = "Long desc"
|
||||
mock_slv.videoUrl = None
|
||||
mock_slv.agentOutputDemoUrl = None
|
||||
mock_slv.imageUrls = ["https://example.com/img.png"]
|
||||
mock_slv.instructions = None
|
||||
mock_slv.categories = ["productivity"]
|
||||
mock_slv.version = 1
|
||||
mock_slv.agentGraphId = GRAPH_ID
|
||||
mock_slv.agentGraphVersion = GRAPH_VERSION
|
||||
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
|
||||
mock_slv.recommendedScheduleCron = "0 9 * * *"
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = "listing-id"
|
||||
mock_listing.slug = "test-agent"
|
||||
mock_listing.activeVersionId = SLV_ID
|
||||
mock_listing.hasApprovedVersion = False
|
||||
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
|
||||
mock_slv.StoreListing = mock_listing
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
|
||||
) as mock_store_agent_prisma,
|
||||
):
|
||||
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await get_store_agent_details_as_admin(SLV_ID)
|
||||
|
||||
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
|
||||
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
|
||||
await_args = mock_slv_prisma.return_value.find_unique.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["where"] == {"id": SLV_ID}
|
||||
|
||||
# Verify the APPROVED-only StoreAgent view was NOT touched
|
||||
mock_store_agent_prisma.assert_not_called()
|
||||
|
||||
# Verify the result has the right data
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_image == ["https://example.com/img.png"]
|
||||
assert result.has_approved_version is False
|
||||
assert result.runs == 0
|
||||
assert result.rating == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
|
||||
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
|
||||
not get_graph. This is THE test that prevents the add-to-library bypass
|
||||
from being accidentally reverted."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_admin.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
|
||||
)
|
||||
mock_regular.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_regular_uses_get_graph() -> None:
|
||||
"""resolve_graph_for_library(admin=False) must call get_graph,
|
||||
not get_graph_as_admin. Ensures the non-admin path is preserved."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_regular.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
|
||||
)
|
||||
mock_admin.assert_not_awaited()
|
||||
|
||||
|
||||
# ---- Library membership grants graph access (product decision) ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_member_can_view_pending_agent_in_builder() -> None:
|
||||
"""After adding a pending agent to their library, the user should be
|
||||
able to load the graph in the builder via get_graph()."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library membership should grant graph access"
|
||||
@@ -11,15 +11,17 @@ from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
ChatSessionMetadata,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
@@ -30,8 +32,14 @@ from backend.copilot.model import (
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
increment_daily_reset_count,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
@@ -59,9 +67,16 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -69,8 +84,6 @@ _UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -99,6 +112,23 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: CopilotMode | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
|
||||
Extra/unknown fields are rejected (422) to prevent silent mis-use.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
dry_run: bool = False
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -107,6 +137,7 @@ class CreateSessionResponse(BaseModel):
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
@@ -125,8 +156,11 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
has_more_messages: bool = False
|
||||
oldest_sequence: int | None = None
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -237,6 +271,7 @@ async def list_sessions(
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
request: CreateSessionRequest | None = None,
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -245,22 +280,28 @@ async def create_session(
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
request: Optional request body. When provided, ``dry_run=True``
|
||||
forces run_block and run_agent calls to use dry-run simulation.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
dry_run = request.dry_run if request else False
|
||||
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
f"{', dry_run=True' if dry_run else ''}"
|
||||
)
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id,
|
||||
metadata=session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -356,59 +397,78 @@ async def update_session_title_route(
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
before_sequence: int | None = Query(default=None, ge=0),
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
Supports cursor-based pagination via ``limit`` and ``before_sequence``.
|
||||
When no pagination params are provided, returns the most recent messages.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of messages to return (1-200, default 50).
|
||||
before_sequence: Return messages with sequence < this value (cursor).
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
SessionDetailResponse: Details for the requested session, including
|
||||
active_stream info and pagination metadata.
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
page = await get_chat_messages_paginated(
|
||||
session_id, limit, before_sequence, user_id=user_id
|
||||
)
|
||||
if page is None:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
messages = [message.model_dump() for message in page.messages]
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
# Only check active stream on initial load (not on "load more" requests)
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
if before_sequence is None:
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Skip session metadata on "load more" — frontend only needs messages
|
||||
if before_sequence is not None:
|
||||
return SessionDetailResponse(
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=None,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=0,
|
||||
total_completion_tokens=0,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
total_prompt = sum(u.prompt_tokens for u in page.session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in page.session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
id=page.session.session_id,
|
||||
created_at=page.session.started_at.isoformat(),
|
||||
updated_at=page.session.updated_at.isoformat(),
|
||||
user_id=page.session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
has_more_messages=page.has_more,
|
||||
oldest_sequence=page.oldest_sequence,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=page.session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -421,11 +481,193 @@ async def get_copilot_usage(
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
Includes the user's rate-limit tier.
|
||||
"""
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
"""Response from resetting the daily rate limit."""
|
||||
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/reset",
|
||||
status_code=200,
|
||||
responses={
|
||||
400: {
|
||||
"description": "Bad Request (feature disabled or daily limit not reached)"
|
||||
},
|
||||
402: {"description": "Payment Required (insufficient credits)"},
|
||||
429: {
|
||||
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
|
||||
},
|
||||
503: {
|
||||
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
|
||||
},
|
||||
},
|
||||
)
|
||||
async def reset_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
"""
|
||||
cost = config.rate_limit_reset_cost
|
||||
if cost <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available.",
|
||||
)
|
||||
|
||||
if not settings.config.enable_credit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit, tier = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No daily limit is configured — nothing to reset.",
|
||||
)
|
||||
|
||||
# Check max daily resets. get_daily_reset_count returns None when Redis
|
||||
# is unavailable; reject the reset in that case to prevent unlimited
|
||||
# free resets when the counter store is down.
|
||||
reset_count = await get_daily_reset_count(user_id)
|
||||
if reset_count is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Unable to verify reset eligibility — please try again later.",
|
||||
)
|
||||
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You've used all {config.max_daily_resets} resets for today.",
|
||||
)
|
||||
|
||||
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
|
||||
if not await acquire_reset_lock(user_id):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="A reset is already in progress. Please try again.",
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
# (rate_limit_reset_cost intentionally omitted — this object is only
|
||||
# used for limit checks, not returned to the client.)
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
tier=tier,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="You have not reached your daily limit yet.",
|
||||
)
|
||||
|
||||
# If the weekly limit is also exhausted, resetting the daily counter
|
||||
# won't help — the user would still be blocked by the weekly limit.
|
||||
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
|
||||
)
|
||||
|
||||
# Charge credits.
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
try:
|
||||
remaining = await credit_model.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
reason="CoPilot daily rate limit reset",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient credits to reset your rate limit.",
|
||||
) from e
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
await credit_model.top_up_credits(user_id, cost)
|
||||
refunded = True
|
||||
logger.warning(
|
||||
"Refunded %d credits to user %s after Redis reset failure",
|
||||
cost,
|
||||
user_id[:8],
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"CRITICAL: Failed to refund %d credits to user %s "
|
||||
"after Redis reset failure — manual intervention required",
|
||||
cost,
|
||||
user_id[:8],
|
||||
exc_info=True,
|
||||
)
|
||||
if refunded:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed — please try again later. "
|
||||
"Your credits have not been charged.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed and the automatic refund "
|
||||
"also failed. Please contact support for assistance.",
|
||||
)
|
||||
|
||||
# Track the reset count for daily cap enforcement.
|
||||
await increment_daily_reset_count(user_id)
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status.
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
tier=tier,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=updated_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -526,12 +768,16 @@ async def stream_chat_post(
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit, _ = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
@@ -620,6 +866,7 @@ async def stream_chat_post(
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
@@ -894,6 +1141,47 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedTheme(BaseModel):
|
||||
"""A themed group of suggested prompts."""
|
||||
|
||||
name: str
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts grouped by theme."""
|
||||
|
||||
themes: list[SuggestedTheme]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts grouped by theme.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns empty themes list if no custom
|
||||
prompts are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None or not understanding.suggested_prompts:
|
||||
return SuggestedPromptsResponse(themes=[])
|
||||
|
||||
themes = [
|
||||
SuggestedTheme(name=name, prompts=prompts)
|
||||
for name, prompts in understanding.suggested_prompts.items()
|
||||
]
|
||||
return SuggestedPromptsResponse(themes=themes)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -942,7 +1230,7 @@ async def health_check() -> dict:
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
session = await create_chat_session(health_check_user_id, dry_run=False)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
from backend.copilot.rate_limit import SubscriptionTier
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
@@ -331,14 +332,28 @@ def _mock_usage(
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
daily_limit: int = 10000,
|
||||
weekly_limit: int = 50000,
|
||||
tier: "SubscriptionTier" = SubscriptionTier.FREE,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
|
||||
|
||||
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
|
||||
``get_usage_status`` so that tests exercise the endpoint without hitting
|
||||
LaunchDarkly or Prisma.
|
||||
"""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(daily_limit, weekly_limit, tier),
|
||||
)
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
@@ -368,6 +383,8 @@ def test_usage_returns_daily_and_weekly(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -375,11 +392,10 @@ def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
|
||||
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
@@ -388,6 +404,8 @@ def test_usage_uses_config_limits(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
tier=SubscriptionTier.FREE,
|
||||
)
|
||||
|
||||
|
||||
@@ -400,3 +418,164 @@ def test_usage_rejects_unauthenticated_request() -> None:
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_themes(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with themed prompts gets them back as themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {
|
||||
"Learn": ["L1", "L2"],
|
||||
"Create": ["C1"],
|
||||
}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "themes" in data
|
||||
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
|
||||
assert themes_by_name["Learn"] == ["L1", "L2"]
|
||||
assert themes_by_name["Create"] == ["C1"]
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty themes list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but empty prompts gets empty themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
# ─── Create session: dry_run contract ─────────────────────────────────
|
||||
|
||||
|
||||
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock create_chat_session to return a fake session."""
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
async def _fake_create(user_id: str, *, dry_run: bool):
|
||||
return ChatSession.new(user_id, dry_run=dry_run)
|
||||
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.create_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_create,
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_dry_run_true(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions", json={"dry_run": True})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is True
|
||||
|
||||
|
||||
def test_create_session_dry_run_default_false(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Empty body defaults dry_run to False."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is False
|
||||
|
||||
|
||||
def test_create_session_rejects_nested_metadata(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
|
||||
default to ``dry_run=False``. This guards against the common mistake of
|
||||
nesting dry_run inside metadata instead of providing it at the top level."""
|
||||
response = client.post(
|
||||
"/sessions",
|
||||
json={"metadata": {"dry_run": True}},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestStreamChatRequestModeValidation:
|
||||
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
|
||||
|
||||
def test_rejects_invalid_mode_value(self) -> None:
|
||||
"""Any string outside the Literal set must raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
|
||||
|
||||
def test_accepts_fast_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="fast")
|
||||
assert req.mode == "fast"
|
||||
|
||||
def test_accepts_extended_thinking_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="extended_thinking")
|
||||
assert req.mode == "extended_thinking"
|
||||
|
||||
def test_accepts_none_mode(self) -> None:
|
||||
"""``mode=None`` is valid (server decides via feature flags)."""
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode=None)
|
||||
assert req.mode is None
|
||||
|
||||
def test_mode_defaults_to_none_when_omitted(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Override session-scoped fixtures so unit tests run without the server."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup():
|
||||
yield
|
||||
@@ -34,16 +34,21 @@ from backend.data.model import (
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -109,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -138,6 +144,19 @@ class CredentialsMetaResponse(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
return CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
@@ -204,34 +223,20 @@ async def callback(
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -242,19 +247,11 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -267,18 +264,21 @@ async def get_credential(
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
return credential
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
@@ -288,16 +288,22 @@ async def create_credentials(
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(credentials.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot create credentials with a reserved ID",
|
||||
)
|
||||
credentials.provider = provider
|
||||
try:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
return credentials
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
class CredentialsDeletionResponse(BaseModel):
|
||||
@@ -332,15 +338,29 @@ async def delete_credentials(
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,570 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import router
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "test-user-id"
|
||||
|
||||
|
||||
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My API Key",
|
||||
api_key=SecretStr("sk-secret-key-value"),
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My OAuth",
|
||||
access_token=SecretStr("ghp_secret_token"),
|
||||
refresh_token=SecretStr("ghp_refresh_secret"),
|
||||
scopes=["repo", "user"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
|
||||
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My Login",
|
||||
username=SecretStr("admin"),
|
||||
password=SecretStr("s3cret-pass"),
|
||||
)
|
||||
|
||||
|
||||
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="Host Cred",
|
||||
host="https://api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer top-secret")},
|
||||
)
|
||||
|
||||
|
||||
def _make_sdk_default_cred(provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
title=f"{provider} (default)",
|
||||
api_key=SecretStr("sk-platform-secret-key"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestGetCredentialReturnsMetaOnly:
|
||||
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
|
||||
|
||||
def test_api_key_credential_no_secret(self):
|
||||
cred = _make_api_key_cred()
|
||||
with (
|
||||
patch.object(router, "dependencies", []),
|
||||
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
|
||||
):
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-123"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["type"] == "api_key"
|
||||
assert "api_key" not in data
|
||||
assert "sk-secret-key-value" not in str(data)
|
||||
|
||||
def test_oauth2_credential_no_secret(self):
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-456"
|
||||
assert data["scopes"] == ["repo", "user"]
|
||||
assert data["username"] == "testuser"
|
||||
assert "access_token" not in data
|
||||
assert "refresh_token" not in data
|
||||
assert "ghp_" not in str(data)
|
||||
|
||||
def test_user_password_credential_no_secret(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-789")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-789"
|
||||
assert "password" not in data
|
||||
assert "username" not in data or data["username"] is None
|
||||
assert "s3cret-pass" not in str(data)
|
||||
assert "admin" not in str(data)
|
||||
|
||||
def test_host_scoped_credential_no_secret(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-host")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-host"
|
||||
assert data["host"] == "https://api.example.com"
|
||||
assert "headers" not in data
|
||||
assert "top-secret" not in str(data)
|
||||
|
||||
def test_get_credential_wrong_provider_returns_404(self):
|
||||
"""Provider mismatch should return generic 404, not leak credential existence."""
|
||||
cred = _make_api_key_cred(provider="openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_list_credentials_no_secrets(self):
|
||||
"""List endpoint must not leak secrets in any credential."""
|
||||
creds = [_make_api_key_cred(), _make_oauth2_cred()]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
raw = str(resp.json())
|
||||
assert "sk-secret-key-value" not in raw
|
||||
assert "ghp_secret_token" not in raw
|
||||
assert "ghp_refresh_secret" not in raw
|
||||
|
||||
|
||||
class TestSdkDefaultCredentialsNotAccessible:
|
||||
"""SDK default credentials (ID ending in '-default') must be hidden."""
|
||||
|
||||
def test_get_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.get("/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_list_credentials_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_list_by_provider_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[user_cred, sdk_cred]
|
||||
)
|
||||
resp = client.get("/openai/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_delete_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock()
|
||||
resp = client.request("DELETE", "/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.store.get_creds_by_id.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateCredentialNoSecretInResponse:
|
||||
"""POST /{provider}/credentials must not return secrets."""
|
||||
|
||||
def test_create_api_key_no_secret_in_response(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "new-cred",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "New Key",
|
||||
"api_key": "sk-newsecret",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-cred"
|
||||
assert "api_key" not in data
|
||||
assert "sk-newsecret" not in str(data)
|
||||
|
||||
def test_create_with_sdk_default_id_rejected(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "openai-default",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "Sneaky",
|
||||
"api_key": "sk-evil",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Shared logic for adding store agents to a user's library.
|
||||
|
||||
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
|
||||
delegate to these helpers so the duplication-prone create/restore/dedup
|
||||
logic lives in exactly one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_graph_for_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
admin: bool,
|
||||
) -> GraphModel:
|
||||
"""Look up a StoreListingVersion and resolve its graph.
|
||||
|
||||
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
|
||||
APPROVED-only check. Otherwise uses the regular ``get_graph``.
|
||||
"""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
if not slv or not slv.AgentGraph:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
ag = slv.AgentGraph
|
||||
if admin:
|
||||
graph_model = await graph_db.get_graph_as_admin(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
else:
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph_model:
|
||||
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
|
||||
return graph_model
|
||||
|
||||
|
||||
async def add_graph_to_library(
|
||||
store_listing_version_id: str,
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent.
|
||||
|
||||
Uses a create-then-catch-UniqueViolationError-then-update pattern on
|
||||
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
|
||||
This is more robust than ``upsert`` because Prisma's upsert atomicity
|
||||
guarantees are not well-documented for all versions.
|
||||
"""
|
||||
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
|
||||
_include = library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_model.id,
|
||||
"version": graph_model.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Already exists — update to restore if previously soft-deleted/archived
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
if added_agent is None:
|
||||
raise NotFoundError(
|
||||
f"LibraryAgent for graph #{graph_model.id} "
|
||||
f"v{graph_model.version} not found after UniqueViolationError"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -0,0 +1,80 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.errors
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"""When no matching LibraryAgent exists, create inserts a new one."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
created_agent = MagicMock(name="CreatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
assert create_data["User"] == {"connect": {"id": "user-id"}}
|
||||
assert create_data["AgentGraph"] == {
|
||||
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
|
||||
}
|
||||
assert create_data["isCreatedByUser"] is False
|
||||
assert create_data["useGraphIsActiveVersion"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"""UniqueViolationError on create falls back to update."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
updated_agent = MagicMock(name="UpdatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
MagicMock(), message="unique constraint"
|
||||
)
|
||||
)
|
||||
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "user-id",
|
||||
"agentGraphId": "graph-id",
|
||||
"agentGraphVersion": 2,
|
||||
}
|
||||
}
|
||||
update_data = update_call.kwargs["data"]
|
||||
assert update_data["isDeleted"] is False
|
||||
assert update_data["isArchived"] is False
|
||||
@@ -336,12 +336,15 @@ async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
include_archived: bool = False,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if not include_archived:
|
||||
filter["isArchived"] = False
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
@@ -433,32 +436,58 @@ async def create_library_agent(
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
prisma.models.LibraryAgent.prisma(tx).upsert(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_entry.id,
|
||||
"agentGraphVersion": graph_entry.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
"update": {
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"useGraphIsActiveVersion": True,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
@@ -582,7 +611,9 @@ async def update_graph_in_library(
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
library_agent = await get_library_agent_by_graph_id(
|
||||
user_id, created_graph.id, include_archived=True
|
||||
)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
@@ -818,92 +849,38 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Adds a marketplace agent to the user’s library.
|
||||
|
||||
See also: `add_store_agent_to_library_as_admin()` which uses
|
||||
`get_graph_as_admin` to bypass marketplace status checks for admin review.
|
||||
"""
|
||||
Adds an agent from a store listing version to the user's library if they don't already have it.
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
Args:
|
||||
store_listing_version_id: The ID of the store listing version containing the agent.
|
||||
user_id: The user’s library to which the agent is being added.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Adding agent from store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=False
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
|
||||
# Convert to GraphModel to check for HITL blocks
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=graph.id,
|
||||
version=graph.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=False,
|
||||
async def add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
|
||||
APPROVED-only checks, allowing admins to add pending agents for review."""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
logger.warning(
|
||||
f"ADMIN adding agent from store listing version "
|
||||
f"#{store_listing_version_id} to library for user #{user_id}"
|
||||
)
|
||||
if not graph_model:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
# Check if user already has this agent (non-deleted)
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph.id, graph.version
|
||||
):
|
||||
return existing
|
||||
|
||||
# Check for soft-deleted version and restore it
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
}
|
||||
},
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=True
|
||||
)
|
||||
if deleted_agent and deleted_agent.isDeleted:
|
||||
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
@@ -143,15 +141,18 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
)
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.id = "agent1"
|
||||
mock_graph_model.version = 1
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
) # Empty list so _has_human_in_the_loop_blocks returns False
|
||||
@@ -170,37 +171,27 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
assert create_call_args is not None
|
||||
|
||||
# Verify the main structure
|
||||
expected_data = {
|
||||
# Verify the create data structure
|
||||
create_data = create_call_args.kwargs["data"]
|
||||
expected_create = {
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
}
|
||||
|
||||
actual_data = create_call_args[1]["data"]
|
||||
# Check that all expected fields are present
|
||||
for key, value in expected_data.items():
|
||||
assert actual_data[key] == value
|
||||
for key, value in expected_create.items():
|
||||
assert create_data[key] == value
|
||||
|
||||
# Check that settings field is present and is a SafeJson object
|
||||
assert "settings" in actual_data
|
||||
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
|
||||
assert "settings" in create_data
|
||||
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
|
||||
|
||||
# Check include parameter
|
||||
assert create_call_args[1]["include"] == library_agent_include(
|
||||
assert create_call_args.kwargs["include"] == library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
@@ -224,3 +215,141 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id(
|
||||
"test-user",
|
||||
"agent1",
|
||||
7,
|
||||
include_archived=True,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
graph = mocker.Mock(id="graph-id")
|
||||
existing_version = mocker.Mock(version=1, is_active=True)
|
||||
graph_model = mocker.Mock()
|
||||
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
|
||||
current_library_agent = mocker.Mock()
|
||||
updated_library_agent = mocker.Mock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.get_graph_all_versions",
|
||||
new=mocker.AsyncMock(return_value=[existing_version]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.make_graph_model",
|
||||
return_value=graph_model,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.create_graph",
|
||||
new=mocker.AsyncMock(return_value=created_graph),
|
||||
)
|
||||
mock_get_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.get_library_agent_by_graph_id",
|
||||
new=mocker.AsyncMock(return_value=current_library_agent),
|
||||
)
|
||||
mock_update_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.update_library_agent_version_and_settings",
|
||||
new=mocker.AsyncMock(return_value=updated_library_agent),
|
||||
)
|
||||
|
||||
result_graph, result_library_agent = await db.update_graph_in_library(
|
||||
graph,
|
||||
"test-user",
|
||||
)
|
||||
|
||||
assert result_graph is created_graph
|
||||
assert result_library_agent is updated_library_agent
|
||||
assert graph.version == 2
|
||||
graph_model.reassign_ids.assert_called_once_with(
|
||||
user_id="test-user", reassign_graph_id=False
|
||||
)
|
||||
mock_get_library_agent.assert_awaited_once_with(
|
||||
"test-user",
|
||||
"graph-id",
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_library_agent_uses_upsert():
|
||||
"""create_library_agent should use upsert (not create) to handle duplicates."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-1"
|
||||
mock_graph.version = 1
|
||||
mock_graph.user_id = "user-1"
|
||||
mock_graph.nodes = []
|
||||
mock_graph.sub_graphs = []
|
||||
|
||||
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_tx():
|
||||
yield None
|
||||
|
||||
with (
|
||||
patch("backend.api.features.library.db.transaction", fake_tx),
|
||||
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library.db.add_generated_agent_image",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
|
||||
|
||||
result = await db.create_library_agent(mock_graph, "user-1")
|
||||
|
||||
assert len(result) == 1
|
||||
upsert_call = mock_prisma.return_value.upsert.call_args
|
||||
assert upsert_call is not None
|
||||
# Verify the upsert where clause uses the composite unique key
|
||||
where = upsert_call.kwargs["where"]
|
||||
assert "userId_agentGraphId_agentGraphVersion" in where
|
||||
# Verify the upsert data has both create and update branches
|
||||
data = upsert_call.kwargs["data"]
|
||||
assert "create" in data
|
||||
assert "update" in data
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
@@ -12,6 +12,7 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
|
||||
from backend.api.features.v1 import v1_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(v1_router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_onboarding_profile_success(mocker):
|
||||
mock_extract = mocker.patch(
|
||||
"backend.api.features.v1.extract_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
mock_upsert = mocker.patch(
|
||||
"backend.api.features.v1.upsert_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
|
||||
mock_extract.return_value = BusinessUnderstandingInput.model_construct(
|
||||
user_name="John",
|
||||
user_role="Founder/CEO",
|
||||
pain_points=["Finding leads"],
|
||||
suggested_prompts={"Learn": ["How do I automate lead gen?"]},
|
||||
)
|
||||
mock_upsert.return_value = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={
|
||||
"user_name": "John",
|
||||
"user_role": "Founder/CEO",
|
||||
"pain_points": ["Finding leads", "Email & outreach"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_extract.assert_awaited_once()
|
||||
mock_upsert.assert_awaited_once()
|
||||
|
||||
|
||||
def test_onboarding_profile_missing_fields():
|
||||
response = client.post(
|
||||
"/onboarding/profile",
|
||||
json={"user_name": "John"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -391,6 +391,11 @@ async def get_available_graph(
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details from the StoreAgent view (APPROVED agents only).
|
||||
|
||||
See also: `get_store_agent_details_as_admin()` which bypasses the
|
||||
APPROVED-only StoreAgent view for admin preview of pending submissions.
|
||||
"""
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -411,6 +416,57 @@ async def get_store_agent_by_version_id(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_agent_details_as_admin(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details for admin preview, bypassing the APPROVED-only
|
||||
StoreAgent view. Queries StoreListingVersion directly so pending
|
||||
submissions are visible."""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": {"include": {"CreatorProfile": True}},
|
||||
},
|
||||
)
|
||||
if not slv or not slv.StoreListing:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
listing = slv.StoreListing
|
||||
# CreatorProfile is a required FK relation — should always exist.
|
||||
# If it's None, the DB is in a bad state.
|
||||
profile = listing.CreatorProfile
|
||||
if not profile:
|
||||
raise DatabaseError(
|
||||
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
|
||||
)
|
||||
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=slv.id,
|
||||
slug=listing.slug,
|
||||
agent_name=slv.name,
|
||||
agent_video=slv.videoUrl or "",
|
||||
agent_output_demo=slv.agentOutputDemoUrl or "",
|
||||
agent_image=slv.imageUrls,
|
||||
creator=profile.username,
|
||||
creator_avatar=profile.avatarUrl or "",
|
||||
sub_heading=slv.subHeading,
|
||||
description=slv.description,
|
||||
instructions=slv.instructions,
|
||||
categories=slv.categories,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
versions=[str(slv.version)],
|
||||
graph_id=slv.agentGraphId,
|
||||
graph_versions=[str(slv.agentGraphVersion)],
|
||||
last_updated=slv.updatedAt,
|
||||
recommended_schedule_cron=slv.recommendedScheduleCron,
|
||||
active_version_id=listing.activeVersionId or slv.id,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
)
|
||||
|
||||
|
||||
class StoreCreatorsSortOptions(Enum):
|
||||
# NOTE: values correspond 1:1 to columns of the Creator view
|
||||
AGENT_RATING = "agent_rating"
|
||||
|
||||
@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
|
||||
@@ -63,12 +63,17 @@ from backend.data.onboarding import (
|
||||
UserOnboardingUpdate,
|
||||
complete_onboarding_step,
|
||||
complete_re_run_agent,
|
||||
format_onboarding_for_extraction,
|
||||
get_recommended_agents,
|
||||
get_user_onboarding,
|
||||
onboarding_enabled,
|
||||
reset_user_onboarding,
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.tally import extract_business_understanding
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
@@ -282,35 +287,33 @@ async def get_onboarding_agents(
|
||||
return await get_recommended_agents(user_id)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding status check."""
|
||||
class OnboardingProfileRequest(pydantic.BaseModel):
|
||||
"""Request body for onboarding profile submission."""
|
||||
|
||||
is_onboarding_enabled: bool
|
||||
is_chat_enabled: bool
|
||||
user_name: str = pydantic.Field(min_length=1, max_length=100)
|
||||
user_role: str = pydantic.Field(min_length=1, max_length=100)
|
||||
pain_points: list[str] = pydantic.Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||
"""Response for onboarding completion check."""
|
||||
|
||||
is_completed: bool
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Is onboarding enabled",
|
||||
"/onboarding/completed",
|
||||
summary="Check if onboarding is completed",
|
||||
tags=["onboarding", "public"],
|
||||
response_model=OnboardingStatusResponse,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def is_onboarding_enabled(
|
||||
async def is_onboarding_completed(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> OnboardingStatusResponse:
|
||||
# Check if chat is enabled for user
|
||||
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||
|
||||
# If chat is enabled, skip legacy onboarding
|
||||
if is_chat_enabled:
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=False,
|
||||
is_chat_enabled=True,
|
||||
)
|
||||
|
||||
user_onboarding = await get_user_onboarding(user_id)
|
||||
return OnboardingStatusResponse(
|
||||
is_onboarding_enabled=await onboarding_enabled(),
|
||||
is_chat_enabled=False,
|
||||
is_completed=OnboardingStep.VISIT_COPILOT in user_onboarding.completedSteps,
|
||||
)
|
||||
|
||||
|
||||
@@ -325,6 +328,38 @@ async def reset_onboarding(user_id: Annotated[str, Security(get_user_id)]):
|
||||
return await reset_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/onboarding/profile",
|
||||
summary="Submit onboarding profile",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def submit_onboarding_profile(
|
||||
data: OnboardingProfileRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
):
|
||||
formatted = format_onboarding_for_extraction(
|
||||
user_name=data.user_name,
|
||||
user_role=data.user_role,
|
||||
pain_points=data.pain_points,
|
||||
)
|
||||
|
||||
try:
|
||||
understanding_input = await extract_business_understanding(formatted)
|
||||
except Exception:
|
||||
understanding_input = BusinessUnderstandingInput.model_construct()
|
||||
|
||||
# Ensure the direct fields are set even if LLM missed them
|
||||
understanding_input.user_name = data.user_name
|
||||
understanding_input.user_role = data.user_role
|
||||
if not understanding_input.pain_points:
|
||||
understanding_input.pain_points = data.pain_points
|
||||
|
||||
await upsert_business_understanding(user_id, understanding_input)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Blocks ###########################
|
||||
########################################################
|
||||
@@ -980,14 +1015,16 @@ async def execute_graph(
|
||||
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
dry_run: Annotated[bool, Body(embed=True)] = False,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
if not dry_run:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
@@ -997,6 +1034,7 @@ async def execute_graph(
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
|
||||
@@ -12,7 +12,7 @@ import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
@@ -131,9 +131,26 @@ class StorageUsageResponse(BaseModel):
|
||||
file_count: int
|
||||
|
||||
|
||||
class WorkspaceFileItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
files: list[WorkspaceFileItem]
|
||||
offset: int = 0
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
operation_id="getWorkspaceDownloadFileById",
|
||||
)
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -158,6 +175,7 @@ async def download_file(
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
operation_id="deleteWorkspaceFile",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -183,6 +201,7 @@ async def delete_workspace_file(
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
operation_id="uploadWorkspaceFile",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -196,6 +215,9 @@ async def upload_file(
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
# Empty-string session_id drops session scoping; normalize to None.
|
||||
session_id = session_id or None
|
||||
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
@@ -250,16 +272,27 @@ async def upload_file(
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
content, filename, overwrite=overwrite, metadata={"origin": "user-upload"}
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
# write_file raises ValueError for both path-conflict and size-limit
|
||||
# cases; map each to its correct HTTP status.
|
||||
message = str(e)
|
||||
if message.startswith("File too large"):
|
||||
raise fastapi.HTTPException(status_code=413, detail=message) from e
|
||||
raise fastapi.HTTPException(status_code=409, detail=message) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
try:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to soft-delete over-quota file {workspace_file.id} "
|
||||
f"in workspace {workspace.id}: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
@@ -281,6 +314,7 @@ async def upload_file(
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
operation_id="getWorkspaceStorageUsage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
@@ -301,3 +335,57 @@ async def get_storage_usage(
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files",
|
||||
summary="List workspace files",
|
||||
operation_id="listWorkspaceFiles",
|
||||
)
|
||||
async def list_workspace_files(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
session_id: str | None = Query(default=None),
|
||||
limit: int = Query(default=200, ge=1, le=1000),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListFilesResponse:
|
||||
"""
|
||||
List files in the user's workspace.
|
||||
|
||||
When session_id is provided, only files for that session are returned.
|
||||
Otherwise, all files across sessions are listed. Results are paginated
|
||||
via `limit`/`offset`; `has_more` indicates whether additional pages exist.
|
||||
"""
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Treat empty-string session_id the same as omitted — an empty value
|
||||
# would otherwise silently list files across every session instead of
|
||||
# scoping to one.
|
||||
session_id = session_id or None
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
include_all = session_id is None
|
||||
# Fetch one extra to compute has_more without a separate count query.
|
||||
files = await manager.list_files(
|
||||
limit=limit + 1,
|
||||
offset=offset,
|
||||
include_all_sessions=include_all,
|
||||
)
|
||||
has_more = len(files) > limit
|
||||
page = files[:limit]
|
||||
|
||||
return ListFilesResponse(
|
||||
files=[
|
||||
WorkspaceFileItem(
|
||||
id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
metadata=f.metadata or {},
|
||||
created_at=f.created_at.isoformat(),
|
||||
)
|
||||
for f in page
|
||||
],
|
||||
offset=offset,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
@@ -1,48 +1,28 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
from backend.api.features.workspace.routes import router
|
||||
from backend.data.workspace import Workspace, WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
"""Mirror the production ValueError → 400 mapping from the REST app."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
@@ -53,25 +33,201 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_workspace(user_id: str = "test-user-id") -> Workspace:
|
||||
return Workspace(
|
||||
id="ws-001",
|
||||
user_id=user_id,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_file(**overrides) -> WorkspaceFile:
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"workspace_id": "ws-001",
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"storage_path": "local://test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"checksum": None,
|
||||
"is_deleted": False,
|
||||
"deleted_at": None,
|
||||
"metadata": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return WorkspaceFile(**defaults)
|
||||
|
||||
|
||||
def _make_file_mock(**overrides) -> MagicMock:
|
||||
"""Create a mock WorkspaceFile to simulate DB records with null fields."""
|
||||
defaults = {
|
||||
"id": "file-001",
|
||||
"name": "test.txt",
|
||||
"path": "/test.txt",
|
||||
"mime_type": "text/plain",
|
||||
"size_bytes": 100,
|
||||
"metadata": {},
|
||||
"created_at": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
mock = MagicMock(spec=WorkspaceFile)
|
||||
for k, v in defaults.items():
|
||||
setattr(mock, k, v)
|
||||
return mock
|
||||
|
||||
|
||||
# -- list_workspace_files tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_returns_all_when_no_session(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt", metadata={"origin": "user-upload"}),
|
||||
_make_file(id="f2", name="b.csv", metadata={"origin": "agent-created"}),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data["files"]) == 2
|
||||
assert data["has_more"] is False
|
||||
assert data["offset"] == 0
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][0]["metadata"] == {"origin": "user-upload"}
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_scopes_to_session_when_provided(
|
||||
mock_manager_cls, mock_get_workspace, test_user_id
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace(user_id=test_user_id)
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?session_id=sess-123")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["files"] == []
|
||||
assert data["has_more"] is False
|
||||
mock_manager_cls.assert_called_once_with(test_user_id, "ws-001", "sess-123")
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=201, offset=0, include_all_sessions=False
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_null_metadata_coerced_to_empty_dict(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Route uses `f.metadata or {}` for pre-existing files with null metadata."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = [_make_file_mock(metadata=None)]
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["files"][0]["metadata"] == {}
|
||||
|
||||
|
||||
# -- upload_file metadata tests --
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_passes_user_upload_origin_metadata(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
written = _make_file(id="new-file", name="doc.pdf")
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.return_value = written
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("doc.pdf", b"fake-pdf-content", "application/pdf")},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_instance.write_file.assert_called_once()
|
||||
call_kwargs = mock_instance.write_file.call_args
|
||||
assert call_kwargs.kwargs.get("metadata") == {"origin": "user-upload"}
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.get_workspace_total_size")
|
||||
@patch("backend.api.features.workspace.routes.scan_content_safe")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_upload_returns_409_on_file_conflict(
|
||||
mock_manager_cls, mock_scan, mock_total_size, mock_get_workspace
|
||||
):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_total_size.return_value = 100
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.write_file.side_effect = ValueError("File already exists at path")
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.post(
|
||||
"/files/upload",
|
||||
files={"file": ("dup.txt", b"content", "text/plain")},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
# -- Restored upload/download/delete security + invariant tests --
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
_MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-001",
|
||||
created_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
name="hello.txt",
|
||||
path="/sessions/sess-1/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_happy_path(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -82,7 +238,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -96,10 +252,7 @@ def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_exceeds_max_file_size(mocker):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
@@ -109,15 +262,11 @@ def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_storage_quota_exceeded(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
@@ -128,27 +277,22 @@ def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
def test_upload_post_write_quota_race(mocker):
|
||||
"""Concurrent upload tipping over limit after write should soft-delete + 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
side_effect=[0, 600 * 1024 * 1024],
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -160,17 +304,14 @@ def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-001")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_any_extension(mocker):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -181,7 +322,7 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -191,16 +332,13 @@ def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_blocked_by_virus_scan(mocker):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -211,7 +349,7 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -219,18 +357,14 @@ def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_file_without_extension(mocker):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -241,7 +375,7 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
@@ -257,14 +391,11 @@ def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
def test_upload_strips_path_components(mocker):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
@@ -275,28 +406,23 @@ def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=_MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_download_file_not_found(mocker):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
@@ -307,14 +433,11 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_success(mocker):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
@@ -329,11 +452,11 @@ def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_not_found(mocker):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
@@ -347,7 +470,7 @@ def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
def test_delete_file_no_workspace(mocker):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
@@ -357,3 +480,123 @@ def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_too_large_returns_413(mocker):
|
||||
"""write_file raises ValueError("File too large: …") → must map to 413."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File too large: 900 bytes exceeds 1MB limit")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "File too large" in response.text
|
||||
|
||||
|
||||
def test_upload_write_file_conflict_returns_409(mocker):
|
||||
"""Non-'File too large' ValueErrors from write_file stay as 409."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=_make_workspace(),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(
|
||||
side_effect=ValueError("File already exists at path: /sessions/x/a.txt")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.text
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_true_when_limit_exceeded(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""The limit+1 fetch trick must flip has_more=True and trim the page."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
# Backend was asked for limit+1=3, and returned exactly 3 items.
|
||||
files = [
|
||||
_make_file(id="f1", name="a.txt"),
|
||||
_make_file(id="f2", name="b.txt"),
|
||||
_make_file(id="f3", name="c.txt"),
|
||||
]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is True
|
||||
assert len(data["files"]) == 2
|
||||
assert data["files"][0]["id"] == "f1"
|
||||
assert data["files"][1]["id"] == "f2"
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=3, offset=0, include_all_sessions=True
|
||||
)
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_has_more_false_when_exactly_page_size(
|
||||
mock_manager_cls, mock_get_workspace
|
||||
):
|
||||
"""Exactly `limit` rows means we're on the last page — has_more=False."""
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
files = [_make_file(id="f1", name="a.txt"), _make_file(id="f2", name="b.txt")]
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = files
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?limit=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is False
|
||||
assert len(data["files"]) == 2
|
||||
|
||||
|
||||
@patch("backend.api.features.workspace.routes.get_or_create_workspace")
|
||||
@patch("backend.api.features.workspace.routes.WorkspaceManager")
|
||||
def test_list_files_offset_is_echoed_back(mock_manager_cls, mock_get_workspace):
|
||||
mock_get_workspace.return_value = _make_workspace()
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.list_files.return_value = []
|
||||
mock_manager_cls.return_value = mock_instance
|
||||
|
||||
response = client.get("/files?offset=50&limit=10")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["offset"] == 50
|
||||
mock_instance.list_files.assert_called_once_with(
|
||||
limit=11, offset=50, include_all_sessions=True
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -117,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -318,6 +324,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.rate_limit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -528,8 +539,11 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
):
|
||||
from backend.api.features.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
|
||||
@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
# Validate the input data (original or reviewer-modified) once.
|
||||
# In dry-run mode, credential fields may contain sentinel None values
|
||||
# that would fail JSON schema required checks. We still validate the
|
||||
# non-credential fields so blocks that execute for real during dry-run
|
||||
# (e.g. AgentExecutorBlock) get proper input validation.
|
||||
is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
|
||||
if is_dry_run:
|
||||
cred_field_names = set(self.input_schema.get_credentials_fields().keys())
|
||||
non_cred_data = {
|
||||
k: v for k, v in input_data.items() if k not in cred_field_names
|
||||
}
|
||||
if error := self.input_schema.validate_data(non_cred_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
else:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
|
||||
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
# Check against the nested `inputs` dict, not the top-level node
|
||||
# data — required fields like "topic" live inside data["inputs"],
|
||||
# not at data["topic"].
|
||||
provided = data.get("inputs", {})
|
||||
return set(required_fields) - set(provided)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(
|
||||
cls.get_input_schema(data), data.get("inputs", {})
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
|
||||
@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
|
||||
execution_context=execution_context.model_copy(
|
||||
update={"parent_execution_id": graph_exec_id},
|
||||
),
|
||||
dry_run=execution_context.dry_run,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.debug(
|
||||
f"Execution {log_id} received event {event.event_type} with status {event.status}"
|
||||
logger.info(
|
||||
f"Execution {log_id} skipping event {event.event_type} status={event.status} "
|
||||
f"node={getattr(event, 'node_exec_id', '?')}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
logger.info(
|
||||
f"Execution {log_id} graph completed with status {event.status}, "
|
||||
f"yielded {len(yielded_node_exec_ids)} outputs"
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
@@ -15,6 +15,12 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -96,6 +102,65 @@ class AutoPilotBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dry_run: bool = SchemaField(
|
||||
description=(
|
||||
"When enabled, run_block and run_agent tool calls in this "
|
||||
"autopilot session are forced to use dry-run simulation mode. "
|
||||
"No real API calls, side effects, or credits are consumed "
|
||||
"by those tools. Useful for testing agent wiring and "
|
||||
"previewing outputs. "
|
||||
"Only applies when creating a new session (session_id is empty). "
|
||||
"When reusing an existing session_id, the session's original "
|
||||
"dry_run setting is preserved."
|
||||
),
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
@@ -182,11 +247,11 @@ class AutoPilotBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str) -> str:
|
||||
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
@@ -196,6 +261,7 @@ class AutoPilotBlock(Block):
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
@@ -209,14 +275,21 @@ class AutoPilotBlock(Block):
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import collect_copilot_response
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
@@ -225,6 +298,7 @@ class AutoPilotBlock(Block):
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
@@ -271,6 +345,8 @@ class AutoPilotBlock(Block):
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -295,11 +371,20 @@ class AutoPilotBlock(Block):
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(execution_context.user_id)
|
||||
sid = await self.create_session(
|
||||
execution_context.user_id, dry_run=input_data.dry_run
|
||||
)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
@@ -312,6 +397,7 @@ class AutoPilotBlock(Block):
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
@@ -374,3 +460,78 @@ def _reset_recursion(
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
|
||||
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
|
||||
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
|
||||
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
|
||||
description="Reads messages from a Discord channel using a bot token.",
|
||||
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
test_input={
|
||||
"continuous_read": False,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -2,6 +2,8 @@ import copy
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -28,9 +30,9 @@ class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
It takes in a value, name, and description.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -47,12 +49,6 @@ class AgentInputBlock(Block):
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
@@ -65,10 +61,7 @@ class AgentInputBlock(Block):
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = copy.deepcopy(self.get_field_schema("value"))
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
return copy.deepcopy(self.get_field_schema("value"))
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field for interface definition
|
||||
@@ -86,18 +79,16 @@ class AgentInputBlock(Block):
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"value": 42,
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"description": "Example numeric input.",
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
("result", 42),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
@@ -245,13 +236,11 @@ class AgentShortTextInputBlock(AgentInputBlock):
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -285,13 +274,11 @@ class AgentLongTextInputBlock(AgentInputBlock):
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -325,13 +312,11 @@ class AgentNumberInputBlock(AgentInputBlock):
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -484,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values to present a dropdown.
|
||||
A specialized text input block that presents a dropdown selector
|
||||
restricted to a fixed set of values.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
@@ -494,13 +480,26 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
# Use Field() directly (not SchemaField) to pass validation_alias,
|
||||
# which handles backward compat for legacy "placeholder_values" across
|
||||
# all construction paths (model_construct, __init__, model_validate).
|
||||
options: list = Field(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
description=(
|
||||
"If provided, renders the input as a dropdown selector "
|
||||
"restricted to these values. Leave empty for free-text input."
|
||||
),
|
||||
validation_alias=AliasChoices("options", "placeholder_values"),
|
||||
json_schema_extra={"advanced": False, "secret": False},
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.options:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
@@ -515,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"options": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
|
||||
@@ -49,6 +49,9 @@ settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
# HTTP status codes for user-caused errors that should not be reported to Sentry.
|
||||
USER_ERROR_STATUS_CODES = (401, 403, 429)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
ProviderName.ANTHROPIC,
|
||||
@@ -101,6 +104,18 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
if isinstance(value, str) and "/" in value:
|
||||
stripped = value.split("/", 1)[1]
|
||||
try:
|
||||
return cls(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
@@ -190,6 +205,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
KIMI_K2 = "moonshotai/kimi-k2"
|
||||
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
|
||||
QWEN3_CODER = "qwen/qwen3-coder"
|
||||
# Z.ai (Zhipu) models
|
||||
ZAI_GLM_4_32B = "z-ai/glm-4-32b"
|
||||
ZAI_GLM_4_5 = "z-ai/glm-4.5"
|
||||
ZAI_GLM_4_5_AIR = "z-ai/glm-4.5-air"
|
||||
ZAI_GLM_4_5_AIR_FREE = "z-ai/glm-4.5-air:free"
|
||||
ZAI_GLM_4_5V = "z-ai/glm-4.5v"
|
||||
ZAI_GLM_4_6 = "z-ai/glm-4.6"
|
||||
ZAI_GLM_4_6V = "z-ai/glm-4.6v"
|
||||
ZAI_GLM_4_7 = "z-ai/glm-4.7"
|
||||
ZAI_GLM_4_7_FLASH = "z-ai/glm-4.7-flash"
|
||||
ZAI_GLM_5 = "z-ai/glm-5"
|
||||
ZAI_GLM_5_TURBO = "z-ai/glm-5-turbo"
|
||||
ZAI_GLM_5V_TURBO = "z-ai/glm-5v-turbo"
|
||||
# Llama API models
|
||||
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
@@ -615,6 +643,43 @@ MODEL_METADATA = {
|
||||
LlmModel.QWEN3_CODER: ModelMetadata(
|
||||
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
|
||||
),
|
||||
# https://openrouter.ai/models?q=z-ai
|
||||
LlmModel.ZAI_GLM_4_32B: ModelMetadata(
|
||||
"open_router", 128000, 128000, "GLM 4 32B", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR: ModelMetadata(
|
||||
"open_router", 131072, 98304, "GLM 4.5 Air", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5_AIR_FREE: ModelMetadata(
|
||||
"open_router", 131072, 96000, "GLM 4.5 Air (Free)", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_5V: ModelMetadata(
|
||||
"open_router", 65536, 16384, "GLM 4.5V", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6: ModelMetadata(
|
||||
"open_router", 204800, 204800, "GLM 4.6", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_6V: ModelMetadata(
|
||||
"open_router", 131072, 131072, "GLM 4.6V", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7: ModelMetadata(
|
||||
"open_router", 202752, 65535, "GLM 4.7", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_4_7_FLASH: ModelMetadata(
|
||||
"open_router", 202752, 202752, "GLM 4.7 Flash", "OpenRouter", "Z.ai", 1
|
||||
),
|
||||
LlmModel.ZAI_GLM_5: ModelMetadata(
|
||||
"open_router", 80000, 80000, "GLM 5", "OpenRouter", "Z.ai", 2
|
||||
),
|
||||
LlmModel.ZAI_GLM_5_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5 Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
LlmModel.ZAI_GLM_5V_TURBO: ModelMetadata(
|
||||
"open_router", 202752, 131072, "GLM 5V Turbo", "OpenRouter", "Z.ai", 3
|
||||
),
|
||||
# Llama API models
|
||||
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
|
||||
"llama_api",
|
||||
@@ -709,6 +774,9 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -724,6 +792,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -891,65 +962,60 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
@@ -962,6 +1028,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1021,12 +1089,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1063,12 +1127,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1098,6 +1158,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1134,6 +1196,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
@@ -1462,7 +1527,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
is_user_error = (
|
||||
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
|
||||
and e.status_code in USER_ERROR_STATUS_CODES
|
||||
)
|
||||
if is_user_error:
|
||||
logger.warning(f"Error calling LLM: {e}")
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
break
|
||||
else:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
@@ -1992,6 +2066,19 @@ class AIConversationBlock(AIBlockBase):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
has_messages = any(
|
||||
isinstance(m, dict)
|
||||
and isinstance(m.get("content"), str)
|
||||
and bool(m["content"].strip())
|
||||
for m in (input_data.messages or [])
|
||||
)
|
||||
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
|
||||
if not has_messages and not has_prompt:
|
||||
raise ValueError(
|
||||
"Cannot call LLM with no messages and no prompt. "
|
||||
"Provide at least one message or a non-empty prompt."
|
||||
)
|
||||
|
||||
response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
|
||||
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
tool_description: str = SchemaField(
|
||||
description="Description of the selected MCP tool. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
323
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.sql_query_helpers import (
|
||||
_DATABASE_TYPE_DEFAULT_PORT,
|
||||
_DATABASE_TYPE_TO_DRIVER,
|
||||
DatabaseType,
|
||||
_execute_query,
|
||||
_sanitize_error,
|
||||
_validate_query_is_read_only,
|
||||
_validate_single_statement,
|
||||
)
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="database",
|
||||
username=SecretStr("test_user"),
|
||||
password=SecretStr("test_pass"),
|
||||
title="Mock Database credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
DatabaseCredentials = UserPasswordCredentials
|
||||
DatabaseCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DATABASE],
|
||||
Literal["user_password"],
|
||||
]
|
||||
|
||||
|
||||
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
|
||||
return CredentialsField(
|
||||
description="Database username and password",
|
||||
)
|
||||
|
||||
|
||||
class SQLQueryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
database_type: DatabaseType = SchemaField(
|
||||
default=DatabaseType.POSTGRES,
|
||||
description="Database engine",
|
||||
advanced=False,
|
||||
)
|
||||
host: SecretStr = SchemaField(
|
||||
description=(
|
||||
"Database hostname or IP address. "
|
||||
"Treated as a secret to avoid leaking infrastructure details. "
|
||||
"Private/internal IPs are blocked (SSRF protection)."
|
||||
),
|
||||
placeholder="db.example.com",
|
||||
secret=True,
|
||||
)
|
||||
port: int | None = SchemaField(
|
||||
default=None,
|
||||
description=(
|
||||
"Database port (leave empty for default: "
|
||||
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
|
||||
),
|
||||
ge=1,
|
||||
le=65535,
|
||||
)
|
||||
database: str = SchemaField(
|
||||
description="Name of the database to connect to",
|
||||
placeholder="my_database",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="SQL query to execute",
|
||||
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
|
||||
)
|
||||
read_only: bool = SchemaField(
|
||||
default=True,
|
||||
description=(
|
||||
"When enabled (default), only SELECT queries are allowed "
|
||||
"and the database session is set to read-only mode. "
|
||||
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
|
||||
),
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=30,
|
||||
description="Query timeout in seconds (max 120)",
|
||||
ge=1,
|
||||
le=120,
|
||||
)
|
||||
max_rows: int = SchemaField(
|
||||
default=1000,
|
||||
description="Maximum number of rows to return (max 10000)",
|
||||
ge=1,
|
||||
le=10000,
|
||||
)
|
||||
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="Query results as a list of row dictionaries"
|
||||
)
|
||||
columns: list[str] = SchemaField(
|
||||
description="Column names from the query result"
|
||||
)
|
||||
row_count: int = SchemaField(description="Number of rows returned")
|
||||
truncated: bool = SchemaField(
|
||||
description=(
|
||||
"True when the result set was capped by max_rows, "
|
||||
"indicating additional rows exist in the database"
|
||||
)
|
||||
)
|
||||
affected_rows: int = SchemaField(
|
||||
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the query failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
|
||||
description=(
|
||||
"Execute a SQL query. Read-only by default for safety "
|
||||
"-- disable to allow write operations. "
|
||||
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
|
||||
),
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=SQLQueryBlock.Input,
|
||||
output_schema=SQLQueryBlock.Output,
|
||||
test_input={
|
||||
"query": "SELECT 1 AS test_col",
|
||||
"database_type": DatabaseType.POSTGRES,
|
||||
"host": "localhost",
|
||||
"database": "test_db",
|
||||
"timeout": 30,
|
||||
"max_rows": 1000,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("results", [{"test_col": 1}]),
|
||||
("columns", ["test_col"]),
|
||||
("row_count", 1),
|
||||
("truncated", False),
|
||||
],
|
||||
test_mock={
|
||||
"execute_query": lambda *_args, **_kwargs: (
|
||||
[{"test_col": 1}],
|
||||
["test_col"],
|
||||
-1,
|
||||
False,
|
||||
),
|
||||
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_host_allowed(host: str) -> list[str]:
|
||||
"""Validate that the given host is not a private/blocked address.
|
||||
|
||||
Returns the list of resolved IP addresses so the caller can pin the
|
||||
connection to the validated IP (preventing DNS rebinding / TOCTOU).
|
||||
Raises ValueError or OSError if the host is blocked.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return await resolve_and_check_blocked(host)
|
||||
|
||||
@staticmethod
|
||||
def execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
|
||||
|
||||
Delegates to ``_execute_query`` in ``sql_query_helpers``.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return _execute_query(
|
||||
connection_url=connection_url,
|
||||
query=query,
|
||||
timeout=timeout,
|
||||
max_rows=max_rows,
|
||||
read_only=read_only,
|
||||
database_type=database_type,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: DatabaseCredentials,
|
||||
**_kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
# Validate query structure and read-only constraints.
|
||||
error = self._validate_query(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Validate host and resolve for SSRF protection.
|
||||
host, pinned_host, error = await self._resolve_host(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Build connection URL and execute.
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
|
||||
username = credentials.username.get_secret_value()
|
||||
connection_url = URL.create(
|
||||
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
|
||||
username=username,
|
||||
password=credentials.password.get_secret_value(),
|
||||
host=pinned_host,
|
||||
port=port,
|
||||
database=input_data.database,
|
||||
)
|
||||
conn_str = connection_url.render_as_string(hide_password=True)
|
||||
db_name = input_data.database
|
||||
|
||||
def _sanitize(err: Exception) -> str:
|
||||
return _sanitize_error(
|
||||
str(err).strip(),
|
||||
conn_str,
|
||||
host=pinned_host,
|
||||
original_host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
database=db_name,
|
||||
)
|
||||
|
||||
try:
|
||||
results, columns, affected, truncated = await asyncio.to_thread(
|
||||
self.execute_query,
|
||||
connection_url=connection_url,
|
||||
query=input_data.query,
|
||||
timeout=input_data.timeout,
|
||||
max_rows=input_data.max_rows,
|
||||
read_only=input_data.read_only,
|
||||
database_type=input_data.database_type,
|
||||
)
|
||||
yield "results", results
|
||||
yield "columns", columns
|
||||
yield "row_count", len(results)
|
||||
yield "truncated", truncated
|
||||
if affected >= 0:
|
||||
yield "affected_rows", affected
|
||||
except OperationalError as e:
|
||||
yield (
|
||||
"error",
|
||||
self._classify_operational_error(
|
||||
_sanitize(e),
|
||||
input_data.timeout,
|
||||
),
|
||||
)
|
||||
except ProgrammingError as e:
|
||||
yield "error", f"SQL error: {_sanitize(e)}"
|
||||
except DBAPIError as e:
|
||||
yield "error", f"Database error: {_sanitize(e)}"
|
||||
except ModuleNotFoundError:
|
||||
yield (
|
||||
"error",
|
||||
(
|
||||
f"Database driver not available for "
|
||||
f"{input_data.database_type.value}. "
|
||||
f"Please contact the platform administrator."
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
|
||||
"""Validate query structure and read-only constraints."""
|
||||
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
|
||||
if stmt_error:
|
||||
return stmt_error
|
||||
assert parsed_stmt is not None
|
||||
if input_data.read_only:
|
||||
return _validate_query_is_read_only(parsed_stmt)
|
||||
return None
|
||||
|
||||
async def _resolve_host(
|
||||
self, input_data: "SQLQueryBlock.Input"
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
return "", "", "Database host is required."
|
||||
if host.startswith("/"):
|
||||
return host, "", "Unix socket connections are not allowed."
|
||||
try:
|
||||
resolved_ips = await self.check_host_allowed(host)
|
||||
except (ValueError, OSError) as e:
|
||||
return host, "", f"Blocked host: {str(e).strip()}"
|
||||
return host, resolved_ips[0], None
|
||||
|
||||
@staticmethod
|
||||
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
|
||||
"""Classify an already-sanitized OperationalError for user display."""
|
||||
lower = sanitized_msg.lower()
|
||||
if "timeout" in lower or "cancel" in lower:
|
||||
return f"Query timed out after {timeout}s."
|
||||
if "connect" in lower:
|
||||
return f"Failed to connect to database: {sanitized_msg}"
|
||||
return f"Database error: {sanitized_msg}"
|
||||
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
1851
autogpt_platform/backend/backend/blocks/sql_query_block_test.py
Normal file
File diff suppressed because it is too large
Load Diff
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
430
autogpt_platform/backend/backend/blocks/sql_query_helpers.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import re
|
||||
from datetime import date, datetime, time
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
|
||||
class DatabaseType(str, Enum):
|
||||
POSTGRES = "postgres"
|
||||
MYSQL = "mysql"
|
||||
MSSQL = "mssql"
|
||||
|
||||
|
||||
# Defense-in-depth: reject queries containing data-modifying keywords.
|
||||
# These are checked against parsed SQL tokens (not raw text) so column names
|
||||
# and string literals do not cause false positives.
|
||||
_DISALLOWED_KEYWORDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"COPY",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"SET",
|
||||
"RESET",
|
||||
"DISCARD",
|
||||
"NOTIFY",
|
||||
"DO",
|
||||
# MySQL file exfiltration: LOAD DATA LOCAL INFILE reads server/client files
|
||||
"LOAD",
|
||||
# MySQL REPLACE is INSERT-or-UPDATE; data modification
|
||||
"REPLACE",
|
||||
# ANSI MERGE (UPSERT) modifies data
|
||||
"MERGE",
|
||||
# MSSQL BULK INSERT loads external files into tables
|
||||
"BULK",
|
||||
# MSSQL EXEC / EXEC sp_name runs stored procedures (arbitrary code)
|
||||
"EXEC",
|
||||
}
|
||||
|
||||
# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
|
||||
_DATABASE_TYPE_TO_DRIVER = {
|
||||
DatabaseType.POSTGRES: "postgresql",
|
||||
DatabaseType.MYSQL: "mysql+pymysql",
|
||||
DatabaseType.MSSQL: "mssql+pymssql",
|
||||
}
|
||||
|
||||
# Connection timeout in seconds passed to the DBAPI driver (connect_timeout /
|
||||
# login_timeout). This bounds how long the driver waits to establish a TCP
|
||||
# connection to the database server. It is separate from the per-statement
|
||||
# timeout configured via SET commands inside _configure_session().
|
||||
_CONNECT_TIMEOUT_SECONDS = 10
|
||||
|
||||
# Default ports for each database type.
|
||||
_DATABASE_TYPE_DEFAULT_PORT = {
|
||||
DatabaseType.POSTGRES: 5432,
|
||||
DatabaseType.MYSQL: 3306,
|
||||
DatabaseType.MSSQL: 1433,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_error(
|
||||
error_msg: str,
|
||||
connection_string: str,
|
||||
*,
|
||||
host: str = "",
|
||||
original_host: str = "",
|
||||
username: str = "",
|
||||
port: int = 0,
|
||||
database: str = "",
|
||||
) -> str:
|
||||
"""Remove connection string, credentials, and infrastructure details
|
||||
from error messages so they are safe to expose to the LLM.
|
||||
|
||||
Scrubs:
|
||||
- The full connection string
|
||||
- URL-embedded credentials (``://user:pass@``)
|
||||
- ``password=<value>`` key-value pairs
|
||||
- The database hostname / IP used for the connection
|
||||
- The original (pre-resolution) hostname provided by the user
|
||||
- Any IPv4 addresses that appear in the message
|
||||
- Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
|
||||
- The database username
|
||||
- The database port number
|
||||
- The database name
|
||||
"""
|
||||
sanitized = error_msg.replace(connection_string, "<connection_string>")
|
||||
sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
|
||||
sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
|
||||
|
||||
# Replace the known host (may be an IP already) before the generic IP pass.
|
||||
# Also replace the original (pre-DNS-resolution) hostname if it differs.
|
||||
if original_host and original_host != host:
|
||||
sanitized = sanitized.replace(original_host, "<host>")
|
||||
if host:
|
||||
sanitized = sanitized.replace(host, "<host>")
|
||||
|
||||
# Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
|
||||
sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
|
||||
|
||||
# Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
|
||||
sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
|
||||
|
||||
# Replace the database username (handles double-quoted, single-quoted,
|
||||
# and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
|
||||
if username:
|
||||
sanitized = re.sub(
|
||||
r"""for user ["']?""" + re.escape(username) + r"""["']?""",
|
||||
"for user <user>",
|
||||
sanitized,
|
||||
)
|
||||
# Catch remaining bare occurrences in various quote styles:
|
||||
# - PostgreSQL: "FATAL: role "myuser" does not exist"
|
||||
# - MySQL: "Access denied for user 'myuser'@'host'"
|
||||
# - MSSQL: "Login failed for user 'myuser'"
|
||||
sanitized = sanitized.replace(f'"{username}"', "<user>")
|
||||
sanitized = sanitized.replace(f"'{username}'", "<user>")
|
||||
|
||||
# Replace the port number (handles "port 5432" and ":5432" formats)
|
||||
if port:
|
||||
port_str = re.escape(str(port))
|
||||
sanitized = re.sub(
|
||||
r"(?:port |:)" + port_str + r"(?![0-9])",
|
||||
lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
|
||||
sanitized,
|
||||
)
|
||||
|
||||
# Replace the database name to avoid leaking internal infrastructure names.
|
||||
# Use word-boundary regex to prevent mangling when the database name is a
|
||||
# common substring (e.g. "test", "data", "on").
|
||||
if database:
|
||||
sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
|
||||
"""Extract keyword tokens from a parsed SQL statement.
|
||||
|
||||
Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
|
||||
tokens. String literals and identifiers have different token types, so
|
||||
they are naturally excluded from the result.
|
||||
"""
|
||||
return [
|
||||
token.normalized.upper()
|
||||
for token in parsed.flatten()
|
||||
if token.ttype
|
||||
in (
|
||||
sqlparse.tokens.Keyword,
|
||||
sqlparse.tokens.Keyword.DML,
|
||||
sqlparse.tokens.Keyword.DDL,
|
||||
sqlparse.tokens.Keyword.DCL,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
|
||||
"""Check if a statement contains a disallowed ``INTO`` clause.
|
||||
|
||||
``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
|
||||
a query result into a session-scoped user variable. All other forms of
|
||||
``INTO`` are data-modifying or file-writing and must be blocked:
|
||||
|
||||
* ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
|
||||
* ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
|
||||
* ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
|
||||
* ``INSERT INTO ...`` (already blocked by INSERT being in the
|
||||
disallowed set, but we reject INTO as well for defense-in-depth)
|
||||
|
||||
Returns ``True`` if the statement contains a disallowed ``INTO``.
|
||||
"""
|
||||
flat = list(stmt.flatten())
|
||||
for i, token in enumerate(flat):
|
||||
if not (
|
||||
token.ttype in (sqlparse.tokens.Keyword,)
|
||||
and token.normalized.upper() == "INTO"
|
||||
):
|
||||
continue
|
||||
|
||||
# Look at the first non-whitespace token after INTO.
|
||||
j = i + 1
|
||||
while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
|
||||
j += 1
|
||||
|
||||
if j >= len(flat):
|
||||
# INTO at the very end – malformed, block it.
|
||||
return True
|
||||
|
||||
next_token = flat[j]
|
||||
# MySQL user variable: either a single Name starting with "@"
|
||||
# (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
|
||||
if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
|
||||
"@"
|
||||
):
|
||||
continue
|
||||
if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
|
||||
continue
|
||||
|
||||
# Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
|
||||
"""Validate that a parsed SQL statement is read-only (SELECT/WITH only).
|
||||
|
||||
Accepts an already-parsed statement from ``_validate_single_statement``
|
||||
to avoid re-parsing. Checks:
|
||||
1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
|
||||
2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
|
||||
3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
|
||||
|
||||
Returns an error message if the query is not read-only, None otherwise.
|
||||
"""
|
||||
# sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
|
||||
if stmt.get_type() != "SELECT":
|
||||
return "Only SELECT queries are allowed."
|
||||
|
||||
# Defense-in-depth: check parsed keyword tokens for disallowed keywords
|
||||
for kw in _extract_keyword_tokens(stmt):
|
||||
# Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
|
||||
base_kw = kw.split()[0] if " " in kw else kw
|
||||
if base_kw in _DISALLOWED_KEYWORDS:
|
||||
return f"Disallowed SQL keyword: {kw}"
|
||||
|
||||
# Contextual check for INTO: allow MySQL @variable syntax, block everything else
|
||||
if _has_disallowed_into(stmt):
|
||||
return "Disallowed SQL keyword: INTO"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_single_statement(
|
||||
query: str,
|
||||
) -> tuple[str | None, sqlparse.sql.Statement | None]:
|
||||
"""Validate that the query contains exactly one non-empty SQL statement.
|
||||
|
||||
Returns (error_message, parsed_statement). If error_message is not None,
|
||||
the query is invalid and parsed_statement will be None.
|
||||
"""
|
||||
stripped = query.strip().rstrip(";").strip()
|
||||
if not stripped:
|
||||
return "Query is empty.", None
|
||||
|
||||
# Parse the SQL using sqlparse for proper tokenization
|
||||
statements = sqlparse.parse(stripped)
|
||||
|
||||
# Filter out empty statements and comment-only statements
|
||||
statements = [
|
||||
s
|
||||
for s in statements
|
||||
if s.tokens
|
||||
and str(s).strip()
|
||||
and not all(
|
||||
t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
|
||||
)
|
||||
]
|
||||
|
||||
if not statements:
|
||||
return "Query is empty.", None
|
||||
|
||||
# Reject multiple statements -- prevents injection via semicolons
|
||||
if len(statements) > 1:
|
||||
return "Only single statements are allowed.", None
|
||||
|
||||
return None, statements[0]
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
"""Convert database-specific types to JSON-serializable Python types."""
|
||||
if isinstance(value, Decimal):
|
||||
# NaN / Infinity are not valid JSON numbers; serialize as strings.
|
||||
if value.is_nan() or value.is_infinite():
|
||||
return str(value)
|
||||
# Use int for whole numbers; use str for fractional to preserve exact
|
||||
# precision (float would silently round high-precision analytics values).
|
||||
if value == value.to_integral_value():
|
||||
return int(value)
|
||||
return str(value)
|
||||
if isinstance(value, (datetime, date, time)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, memoryview):
|
||||
return bytes(value).hex()
|
||||
if isinstance(value, bytes):
|
||||
return value.hex()
|
||||
return value
|
||||
|
||||
|
||||
def _configure_session(
|
||||
conn: Any,
|
||||
dialect_name: str,
|
||||
timeout_ms: str,
|
||||
read_only: bool,
|
||||
) -> None:
|
||||
"""Set session-level timeout and read-only mode for the given dialect.
|
||||
|
||||
Timeout limitations by database:
|
||||
|
||||
* **PostgreSQL** – ``statement_timeout`` reliably cancels any running
|
||||
statement (SELECT or DML) after the configured duration.
|
||||
* **MySQL** – ``MAX_EXECUTION_TIME`` only applies to **read-only SELECT**
|
||||
statements. DML (INSERT/UPDATE/DELETE) and DDL are *not* bounded by
|
||||
this hint; they rely on the server's ``wait_timeout`` /
|
||||
``interactive_timeout`` instead. There is no session-level setting in
|
||||
MySQL that reliably cancels long-running writes.
|
||||
* **MSSQL** – ``SET LOCK_TIMEOUT`` only limits how long the server waits
|
||||
to acquire a **lock**. CPU-bound queries (e.g. large scans, hash
|
||||
joins) that do not block on locks will *not* be cancelled. MSSQL has
|
||||
no session-level ``statement_timeout`` equivalent; the closest
|
||||
mechanism is Resource Governor (requires sysadmin configuration) or
|
||||
``CONTEXT_INFO``-based external monitoring.
|
||||
|
||||
Note: SQLite is not supported by this block. The ``_configure_session``
|
||||
function is a no-op for unrecognised dialect names, so an SQLite engine
|
||||
would skip all SET commands silently. The block's ``DatabaseType`` enum
|
||||
intentionally excludes SQLite.
|
||||
"""
|
||||
if dialect_name == "postgresql":
|
||||
conn.execute(text("SET statement_timeout = " + timeout_ms))
|
||||
if read_only:
|
||||
conn.execute(text("SET default_transaction_read_only = ON"))
|
||||
elif dialect_name == "mysql":
|
||||
# NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
|
||||
# Write queries (INSERT/UPDATE/DELETE) are not bounded by this
|
||||
# setting; they rely on the database's wait_timeout instead.
|
||||
# See docstring above for full limitations.
|
||||
conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
|
||||
if read_only:
|
||||
conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
|
||||
elif dialect_name == "mssql":
|
||||
# MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms) only.
|
||||
# CPU-bound queries without lock contention are NOT cancelled.
|
||||
# See docstring above for full limitations.
|
||||
conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
|
||||
# MSSQL lacks a session-level read-only mode like
|
||||
# PostgreSQL/MySQL. Read-only enforcement is handled by
|
||||
# the SQL validation layer (_validate_query_is_read_only)
|
||||
# and the ROLLBACK in the finally block.
|
||||
|
||||
|
||||
def _run_in_transaction(
|
||||
conn: Any,
|
||||
dialect_name: str,
|
||||
query: str,
|
||||
max_rows: int,
|
||||
read_only: bool,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a query inside an explicit transaction, returning results.
|
||||
|
||||
Returns ``(rows, columns, affected_rows, truncated)`` where *truncated*
|
||||
is ``True`` when ``fetchmany`` returned exactly ``max_rows`` rows,
|
||||
indicating that additional rows may exist in the result set.
|
||||
"""
|
||||
# MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
|
||||
begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
|
||||
conn.execute(text(begin_stmt))
|
||||
try:
|
||||
result = conn.execute(text(query))
|
||||
affected = result.rowcount if not result.returns_rows else -1
|
||||
columns = list(result.keys()) if result.returns_rows else []
|
||||
rows = result.fetchmany(max_rows) if result.returns_rows else []
|
||||
truncated = len(rows) == max_rows
|
||||
results = [
|
||||
{col: _serialize_value(val) for col, val in zip(columns, row)}
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
try:
|
||||
conn.execute(text("ROLLBACK"))
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
else:
|
||||
conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
|
||||
return results, columns, affected, truncated
|
||||
|
||||
|
||||
def _execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int, bool]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows, truncated).
|
||||
|
||||
Uses SQLAlchemy to connect to any supported database.
|
||||
For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
|
||||
``truncated`` is ``True`` when the result set was capped by ``max_rows``.
|
||||
For write queries, affected_rows contains the rowcount from the driver.
|
||||
When ``read_only`` is True, the database session is set to read-only
|
||||
mode and the transaction is always rolled back.
|
||||
"""
|
||||
# Determine driver-specific connection timeout argument.
|
||||
# pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
|
||||
timeout_key = (
|
||||
"login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
|
||||
)
|
||||
engine = create_engine(
|
||||
connection_url, connect_args={timeout_key: _CONNECT_TIMEOUT_SECONDS}
|
||||
)
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
# Use AUTOCOMMIT so SET commands take effect immediately.
|
||||
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
|
||||
# Compute timeout in milliseconds. The value is Pydantic-validated
|
||||
# (ge=1, le=120), but we use int() as defense-in-depth.
|
||||
# NOTE: SET commands do not support bind parameters in most
|
||||
# databases, so we use str(int(...)) for safe interpolation.
|
||||
timeout_ms = str(int(timeout * 1000))
|
||||
|
||||
_configure_session(conn, engine.dialect.name, timeout_ms, read_only)
|
||||
return _run_in_transaction(
|
||||
conn, engine.dialect.name, query, max_rows, read_only
|
||||
)
|
||||
finally:
|
||||
engine.dispose()
|
||||
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
from stagehand import AsyncStagehand
|
||||
from stagehand.types.session_act_params import Options as ActOptions
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -28,46 +23,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,13 +103,10 @@ class StagehandObserveBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -185,32 +137,28 @@ class StagehandObserveBlock(Block):
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
observe_response = await session.observe(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
for result in observe_response.data.result:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
@@ -242,24 +190,22 @@ class StagehandActBlock(Block):
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms for each action.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
message: str = SchemaField(description="Details about the action's execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -282,32 +228,33 @@ class StagehandActBlock(Block):
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
for action in input_data.action:
|
||||
act_options = ActOptions(
|
||||
variables={k: v for k, v in input_data.variables.items()},
|
||||
timeout=input_data.timeout_ms,
|
||||
)
|
||||
act_response = await session.act(
|
||||
input=action,
|
||||
options=act_options,
|
||||
)
|
||||
result = act_response.data.result
|
||||
yield "success", result.success
|
||||
yield "message", result.message
|
||||
yield "action", result.action_description
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
@@ -335,13 +282,10 @@ class StagehandExtractBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -367,24 +311,21 @@ class StagehandExtractBlock(Block):
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
extract_response = await session.extract(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
yield "extraction", str(extract_response.data.result)
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
@@ -4,6 +4,8 @@ import pytest
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks._base import Block, BlockSchemaInput
|
||||
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
@@ -279,3 +281,113 @@ class TestAutoCredentialsFieldsValidation:
|
||||
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
|
||||
for backward compatibility with existing agent JSON."""
|
||||
legacy_data = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
instance = AgentInputBlock.Input.model_construct(**legacy_data)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
"enum" not in schema
|
||||
), "AgentInputBlock should not produce enum from legacy placeholder_values"
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
|
||||
using the canonical 'options' field name."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, options=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
|
||||
def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
|
||||
"""Verify backward compat: passing legacy 'placeholder_values' to
|
||||
AgentDropdownInputBlock still produces enum via model_construct remap."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=opts
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
schema.get("enum") == opts
|
||||
), "Legacy placeholder_values should be remapped to options"
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with legacy placeholder_values
|
||||
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
|
||||
legacy_input_default = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentInputBlock.Input, legacy_input_default),
|
||||
)
|
||||
url_props = result["properties"]["url"]
|
||||
assert (
|
||||
"enum" not in url_props
|
||||
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks using canonical field name."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"options": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
using legacy 'placeholder_values' — verifies backward compat produces enum."""
|
||||
legacy_dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Legacy placeholder_values should still produce enum via model_construct remap"
|
||||
|
||||
|
||||
def test_dropdown_input_block_init_legacy_placeholder_values():
|
||||
"""Verify backward compat: constructing AgentDropdownInputBlock.Input via
|
||||
model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
|
||||
opts = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_validate(
|
||||
{"name": "choice", "value": None, "placeholder_values": opts}
|
||||
)
|
||||
assert (
|
||||
instance.options == opts
|
||||
), "Legacy placeholder_values should be remapped to options via model_validate"
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == opts
|
||||
|
||||
@@ -207,6 +207,51 @@ class TestXMLParserBlockSecurity:
|
||||
pass
|
||||
|
||||
|
||||
class TestXMLParserBlockSyntaxErrors:
|
||||
"""XML syntax errors should raise ValueError (not SyntaxError).
|
||||
|
||||
This ensures the base Block.execute() wraps them as BlockExecutionError
|
||||
(expected / user-caused) instead of BlockUnknownError (unexpected / alerts
|
||||
Sentry).
|
||||
"""
|
||||
|
||||
async def test_unclosed_tag_raises_value_error(self):
|
||||
"""Unclosed tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "<root><unclosed>"
|
||||
|
||||
with pytest.raises(ValueError, match="Unclosed tag"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_unexpected_closing_tag_raises_value_error(self):
|
||||
"""Extra closing tags should raise ValueError, not SyntaxError."""
|
||||
block = XMLParserBlock()
|
||||
bad_xml = "</unexpected>"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
async def test_empty_xml_raises_value_error(self):
|
||||
"""Empty XML input should raise ValueError."""
|
||||
block = XMLParserBlock()
|
||||
|
||||
with pytest.raises(ValueError, match="XML input is empty"):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml="")):
|
||||
pass
|
||||
|
||||
async def test_syntax_error_from_parser_becomes_value_error(self):
|
||||
"""SyntaxErrors from gravitasml library become ValueError (BlockExecutionError)."""
|
||||
block = XMLParserBlock()
|
||||
# Malformed XML that might trigger a SyntaxError from the parser
|
||||
bad_xml = "<root><child>no closing"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in block.run(XMLParserBlock.Input(input_xml=bad_xml)):
|
||||
pass
|
||||
|
||||
|
||||
class TestStoreMediaFileSecurity:
|
||||
"""Test file storage security limits."""
|
||||
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
# TEST_CREDENTIALS_INPUT is a plain dict that satisfies AICredentials at runtime
|
||||
# but not at the type level. Cast once here to avoid per-test suppressors.
|
||||
_TEST_AI_CREDENTIALS = cast(llm.AICredentials, llm.TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
class TestLLMStatsTracking:
|
||||
"""Test that LLM blocks correctly track token usage statistics."""
|
||||
@@ -479,6 +488,154 @@ class TestLLMStatsTracking:
|
||||
assert outputs["response"] == {"result": "test"}
|
||||
|
||||
|
||||
class TestAIConversationBlockValidation:
|
||||
"""Test that AIConversationBlock validates inputs before calling the LLM."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_and_empty_prompt_raises_error(self):
|
||||
"""Empty messages with no prompt should raise ValueError, not a cryptic API error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages_with_prompt_succeeds(self):
|
||||
"""Empty messages but a non-empty prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "OK"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[],
|
||||
prompt="Hello, how are you?",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "OK"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonempty_messages_with_empty_prompt_succeeds(self):
|
||||
"""Non-empty messages with no prompt should proceed without error."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
async def mock_llm_call(input_data, credentials):
|
||||
return {"response": "response from conversation"}
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[name] = data
|
||||
|
||||
assert outputs["response"] == "response from conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_content_raises_error(self):
|
||||
"""Messages with empty content strings should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": ""}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_whitespace_content_raises_error(self):
|
||||
"""Messages with whitespace-only content should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": " "}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_entry_raises_error(self):
|
||||
"""Messages list containing None should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[None],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_empty_dict_raises_error(self):
|
||||
"""Messages list containing empty dict should be treated as no messages."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_none_content_raises_error(self):
|
||||
"""Messages with content=None should not crash with AttributeError."""
|
||||
block = llm.AIConversationBlock()
|
||||
|
||||
input_data = llm.AIConversationBlock.Input(
|
||||
messages=[{"role": "user", "content": None}],
|
||||
prompt="",
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="no messages and no prompt"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
class TestAITextSummarizerValidation:
|
||||
"""Test that AITextSummarizerBlock validates LLM responses are strings."""
|
||||
|
||||
@@ -655,3 +812,178 @@ class TestAITextSummarizerValidation:
|
||||
error_message = str(exc_info.value)
|
||||
assert "Expected a string summary" in error_message
|
||||
assert "received dict" in error_message
|
||||
|
||||
|
||||
def _make_anthropic_status_error(status_code: int) -> anthropic.APIStatusError:
|
||||
"""Create an anthropic.APIStatusError with the given status code."""
|
||||
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
|
||||
response = httpx.Response(status_code, request=request)
|
||||
return anthropic.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
def _make_openai_status_error(status_code: int) -> openai.APIStatusError:
|
||||
"""Create an openai.APIStatusError with the given status code."""
|
||||
response = httpx.Response(
|
||||
status_code, request=httpx.Request("POST", "https://api.openai.com/v1/chat")
|
||||
)
|
||||
return openai.APIStatusError(
|
||||
f"Error code: {status_code}", response=response, body=None
|
||||
)
|
||||
|
||||
|
||||
class TestUserErrorStatusCodeHandling:
|
||||
"""Test that user-caused LLM API errors (401/403/429) break the retry loop
|
||||
and are logged as warnings, while server errors (500) trigger retries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_anthropic_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 Anthropic errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("status_code", [401, 403, 429])
|
||||
async def test_openai_user_error_breaks_retry_loop(self, status_code: int):
|
||||
"""401/403/429 OpenAI errors should break immediately, not retry."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_openai_status_error(status_code)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count == 1
|
||||
), f"Expected exactly 1 call for status {status_code}, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_error_retries(self):
|
||||
"""500 errors should be retried (not break immediately)."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _make_anthropic_status_error(500)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
retry=3,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
assert (
|
||||
call_count > 1
|
||||
), f"Expected multiple retry attempts for 500, got {call_count}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_error_logs_warning_not_exception(self):
|
||||
"""User-caused errors should log with logger.warning, not logger.exception."""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
raise _make_anthropic_status_error(401)
|
||||
|
||||
with patch.object(block, "llm_call", new=AsyncMock(side_effect=mock_llm_call)):
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test",
|
||||
expected_format={"key": "desc"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=_TEST_AI_CREDENTIALS,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(llm.logger, "warning") as mock_warning,
|
||||
patch.object(llm.logger, "exception") as mock_exception,
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
mock_exception.assert_not_called()
|
||||
|
||||
|
||||
class TestLlmModelMissing:
|
||||
"""Test that LlmModel handles provider-prefixed model names."""
|
||||
|
||||
def test_provider_prefixed_model_resolves(self):
|
||||
"""Provider-prefixed model string should resolve to the correct enum member."""
|
||||
assert (
|
||||
llm.LlmModel("anthropic/claude-sonnet-4-6")
|
||||
== llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
)
|
||||
|
||||
def test_bare_model_still_works(self):
|
||||
"""Bare (non-prefixed) model string should still resolve correctly."""
|
||||
assert llm.LlmModel("claude-sonnet-4-6") == llm.LlmModel.CLAUDE_4_6_SONNET
|
||||
|
||||
def test_invalid_prefixed_model_raises(self):
|
||||
"""Unknown provider-prefixed model string should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
llm.LlmModel("invalid/nonexistent-model")
|
||||
|
||||
def test_slash_containing_value_direct_lookup(self):
|
||||
"""Enum values with '/' (e.g., OpenRouter models) should resolve via direct lookup, not _missing_."""
|
||||
assert llm.LlmModel("google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
|
||||
def test_double_prefixed_slash_model(self):
|
||||
"""Double-prefixed value should still resolve by stripping first prefix."""
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for empty-choices guard in extract_openai_tool_calls() and extract_openai_reasoning()."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.llm import extract_openai_reasoning, extract_openai_tool_calls
|
||||
|
||||
|
||||
class TestExtractOpenaiToolCallsEmptyChoices:
|
||||
"""extract_openai_tool_calls() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
def test_returns_tool_calls_when_choices_present(self):
|
||||
tool = MagicMock()
|
||||
tool.id = "call_1"
|
||||
tool.type = "function"
|
||||
tool.function.name = "my_func"
|
||||
tool.function.arguments = '{"a": 1}'
|
||||
|
||||
message = MagicMock()
|
||||
message.tool_calls = [tool]
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_tool_calls(response)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].function.name == "my_func"
|
||||
|
||||
def test_returns_none_when_no_tool_calls(self):
|
||||
message = MagicMock()
|
||||
message.tool_calls = None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
|
||||
assert extract_openai_tool_calls(response) is None
|
||||
|
||||
|
||||
class TestExtractOpenaiReasoningEmptyChoices:
|
||||
"""extract_openai_reasoning() must return None when choices is empty."""
|
||||
|
||||
def test_returns_none_for_empty_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = []
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_none_for_none_choices(self):
|
||||
response = MagicMock()
|
||||
response.choices = None
|
||||
assert extract_openai_reasoning(response) is None
|
||||
|
||||
def test_returns_reasoning_from_choice(self):
|
||||
choice = MagicMock()
|
||||
choice.reasoning = "Step-by-step reasoning"
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr on message
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr on response
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result == "Step-by-step reasoning"
|
||||
|
||||
def test_returns_none_when_no_reasoning(self):
|
||||
choice = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
choice.message = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
|
||||
response = MagicMock(spec=[]) # no 'reasoning' attr
|
||||
response.choices = [choice]
|
||||
|
||||
result = extract_openai_reasoning(response)
|
||||
assert result is None
|
||||
@@ -57,7 +57,7 @@ async def execute_graph(
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -66,7 +66,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -108,10 +108,10 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
async def test_orchestrator_function_signature(server: SpinTestServer):
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data import graph
|
||||
|
||||
test_user = await create_test_user()
|
||||
@@ -120,7 +120,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
block_id=OrchestratorBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": creds,
|
||||
@@ -169,7 +169,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
tool_functions = await OrchestratorBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -198,12 +198,12 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
|
||||
async def test_orchestrator_tracks_llm_stats():
|
||||
"""Test that OrchestratorBlock correctly tracks LLM usage stats."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the llm.llm_call function to return controlled data
|
||||
mock_response = MagicMock()
|
||||
@@ -224,14 +224,14 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -274,12 +274,12 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
async def test_orchestrator_parameter_validation():
|
||||
"""Test that OrchestratorBlock correctly validates tool call parameters."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
@@ -327,13 +327,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -394,13 +394,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -454,13 +454,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -518,13 +518,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -562,12 +562,12 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_raw_response_conversion():
|
||||
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
|
||||
async def test_orchestrator_raw_response_conversion():
|
||||
"""Test that Orchestrator correctly handles different raw_response types with retry mechanism."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool functions
|
||||
mock_tool_functions = [
|
||||
@@ -637,7 +637,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
@@ -646,7 +646,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
# Second call returns successful response
|
||||
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -715,12 +715,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -771,12 +771,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
OrchestratorBlock,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -811,12 +811,12 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_mode():
|
||||
async def test_orchestrator_agent_mode():
|
||||
"""Test that agent mode executes tools directly and loops until finished."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call that requires multiple iterations
|
||||
mock_tool_call_1 = MagicMock()
|
||||
@@ -893,7 +893,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -929,7 +929,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
}
|
||||
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -969,12 +969,12 @@ async def test_smart_decision_maker_agent_mode():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_traditional_mode_default():
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock tool call
|
||||
mock_tool_call = MagicMock()
|
||||
@@ -1018,7 +1018,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
):
|
||||
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
input_data = OrchestratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -1060,12 +1060,12 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
"""Test that OrchestratorBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1074,13 +1074,14 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1091,12 +1092,12 @@ async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
async def test_orchestrator_falls_back_to_block_name():
|
||||
"""Test that OrchestratorBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1105,13 +1106,14 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
result = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1122,11 +1124,11 @@ async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
async def test_orchestrator_uses_customized_name_for_agents():
|
||||
"""Test that OrchestratorBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
@@ -1152,10 +1154,10 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -1166,11 +1168,11 @@ async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
async def test_orchestrator_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
@@ -1196,10 +1198,10 @@ async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
result = await OrchestratorBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,12 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_dict_fields():
|
||||
"""Test Orchestrator can handle dynamic dictionary fields (_#_) for any block"""
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -23,24 +23,24 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
source_name="tools_^_create_dict_~_name",
|
||||
sink_name="values_#_name", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_age",
|
||||
sink_name="values_#_age", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_city",
|
||||
sink_name="values_#_city", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -70,8 +70,8 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
|
||||
async def test_orchestrator_handles_dynamic_list_fields():
|
||||
"""Test Orchestrator can handle dynamic list fields (_$_) for any block"""
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -86,18 +86,18 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
signature = await OrchestratorBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
|
||||
"""Comprehensive tests for OrchestratorBlock dynamic field handling."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.text import MatchTextPatternBlock
|
||||
from backend.data.dynamic_fields import get_dynamic_field_description
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_dynamic_field_description_generation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_dict_fields():
|
||||
"""Test that function signatures are created correctly for dictionary dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
@@ -52,19 +52,19 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
source_name="tools_^_create_dict_~_values___name", # Sanitized source
|
||||
sink_name="values_#_name", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___age", # Sanitized source
|
||||
sink_name="values_#_age", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_values___email", # Sanitized source
|
||||
sink_name="values_#_email", # Original sink
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -100,7 +100,7 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_list_fields():
|
||||
"""Test that function signatures are created correctly for list dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
@@ -115,19 +115,19 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
source_name="tools_^_add_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_list_~_2",
|
||||
sink_name="entries_$_2", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -154,7 +154,7 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_block_function_signature_with_object_fields():
|
||||
"""Test that function signatures are created correctly for object dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node for MatchTextPatternBlock (simulating object fields)
|
||||
mock_node = Mock()
|
||||
@@ -169,13 +169,13 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
source_name="tools_^_extract_~_user_name",
|
||||
sink_name="data_@_user_name", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_extract_~_user_email",
|
||||
sink_name="data_@_user_email", # Dynamic object field
|
||||
sink_id="extract_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -197,11 +197,11 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Mock the database client and connected nodes
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db:
|
||||
mock_client = AsyncMock()
|
||||
mock_db.return_value = mock_client
|
||||
@@ -281,7 +281,7 @@ async def test_create_tool_node_signatures():
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_yielding_with_dynamic_fields():
|
||||
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# No more sanitized mapping needed since we removed sanitization
|
||||
|
||||
@@ -309,13 +309,13 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager, patch.object(
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
@@ -420,7 +420,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_regular_and_dynamic_fields():
|
||||
"""Test handling of blocks with both regular and dynamic fields."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Create a mock node
|
||||
mock_node = Mock()
|
||||
@@ -450,19 +450,19 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
source_name="tools_^_test_~_regular",
|
||||
sink_name="regular_field", # Regular field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key",
|
||||
sink_name="values_#_key1", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_test_~_dict_key2",
|
||||
sink_name="values_#_key2", # Dynamic dict field
|
||||
sink_id="test_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
source_id="orchestrator_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -488,7 +488,7 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_errors_dont_pollute_conversation():
|
||||
"""Test that validation errors are only used during retries and don't pollute the conversation."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# Track conversation history changes
|
||||
conversation_snapshots = []
|
||||
@@ -535,7 +535,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the LLM call
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
|
||||
"backend.blocks.orchestrator.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm:
|
||||
mock_llm.side_effect = mock_llm_call
|
||||
|
||||
@@ -565,7 +565,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the database manager to avoid HTTP calls during tool execution
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client"
|
||||
) as mock_db_manager:
|
||||
# Set up the mock database manager for agent mode
|
||||
mock_db_client = AsyncMock()
|
||||
@@ -0,0 +1,202 @@
|
||||
"""Tests for ExecutionMode enum and provider validation in the orchestrator.
|
||||
|
||||
Covers:
|
||||
- ExecutionMode enum members exist and have stable values
|
||||
- EXTENDED_THINKING provider validation (anthropic/open_router allowed, others rejected)
|
||||
- EXTENDED_THINKING model-name validation (must start with "claude")
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.blocks.orchestrator import ExecutionMode, OrchestratorBlock
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExecutionMode enum integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecutionModeEnum:
|
||||
"""Guard against accidental renames or removals of enum members."""
|
||||
|
||||
def test_built_in_exists(self):
|
||||
assert hasattr(ExecutionMode, "BUILT_IN")
|
||||
assert ExecutionMode.BUILT_IN.value == "built_in"
|
||||
|
||||
def test_extended_thinking_exists(self):
|
||||
assert hasattr(ExecutionMode, "EXTENDED_THINKING")
|
||||
assert ExecutionMode.EXTENDED_THINKING.value == "extended_thinking"
|
||||
|
||||
def test_exactly_two_members(self):
|
||||
"""If a new mode is added, this test should be updated intentionally."""
|
||||
assert set(ExecutionMode.__members__.keys()) == {
|
||||
"BUILT_IN",
|
||||
"EXTENDED_THINKING",
|
||||
}
|
||||
|
||||
def test_string_enum(self):
|
||||
"""ExecutionMode is a str enum so it serialises cleanly to JSON."""
|
||||
assert isinstance(ExecutionMode.BUILT_IN, str)
|
||||
assert isinstance(ExecutionMode.EXTENDED_THINKING, str)
|
||||
|
||||
def test_round_trip_from_value(self):
|
||||
"""Constructing from the string value should return the same member."""
|
||||
assert ExecutionMode("built_in") is ExecutionMode.BUILT_IN
|
||||
assert ExecutionMode("extended_thinking") is ExecutionMode.EXTENDED_THINKING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider validation (inline in OrchestratorBlock.run)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_model_stub(provider: str, value: str):
|
||||
"""Create a lightweight stub that behaves like LlmModel for validation."""
|
||||
metadata = MagicMock()
|
||||
metadata.provider = provider
|
||||
stub = MagicMock()
|
||||
stub.metadata = metadata
|
||||
stub.value = value
|
||||
return stub
|
||||
|
||||
|
||||
class TestExtendedThinkingProviderValidation:
|
||||
"""The orchestrator rejects EXTENDED_THINKING for non-Anthropic providers."""
|
||||
|
||||
def test_anthropic_provider_accepted(self):
|
||||
"""provider='anthropic' + claude model should not raise."""
|
||||
model = _make_model_stub("anthropic", "claude-opus-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_open_router_provider_accepted(self):
|
||||
"""provider='open_router' + claude model should not raise."""
|
||||
model = _make_model_stub("open_router", "claude-sonnet-4-6")
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
def test_openai_provider_rejected(self):
|
||||
"""provider='openai' should be rejected for EXTENDED_THINKING."""
|
||||
model = _make_model_stub("openai", "gpt-4o")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_groq_provider_rejected(self):
|
||||
model = _make_model_stub("groq", "llama-3.3-70b-versatile")
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_non_claude_model_rejected_even_if_anthropic_provider(self):
|
||||
"""A hypothetical non-Claude model with provider='anthropic' is rejected."""
|
||||
model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
model_name = model.value
|
||||
assert not model_name.startswith("claude")
|
||||
|
||||
def test_real_gpt4o_model_rejected(self):
|
||||
"""Verify a real LlmModel enum member (GPT4O) fails the provider check."""
|
||||
model = LlmModel.GPT4O
|
||||
provider = model.metadata.provider
|
||||
assert provider not in ("anthropic", "open_router")
|
||||
|
||||
def test_real_claude_model_passes(self):
|
||||
"""Verify a real LlmModel enum member (CLAUDE_4_6_SONNET) passes."""
|
||||
model = LlmModel.CLAUDE_4_6_SONNET
|
||||
provider = model.metadata.provider
|
||||
model_name = model.value
|
||||
assert provider in ("anthropic", "open_router")
|
||||
assert model_name.startswith("claude")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: exercise the validation branch via OrchestratorBlock.run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input_data(model, execution_mode=ExecutionMode.EXTENDED_THINKING):
|
||||
"""Build a minimal MagicMock that satisfies OrchestratorBlock.run's early path."""
|
||||
inp = MagicMock()
|
||||
inp.execution_mode = execution_mode
|
||||
inp.model = model
|
||||
inp.prompt = "test"
|
||||
inp.sys_prompt = ""
|
||||
inp.conversation_history = []
|
||||
inp.last_tool_output = None
|
||||
inp.prompt_values = {}
|
||||
return inp
|
||||
|
||||
|
||||
async def _collect_run_outputs(block, input_data, **kwargs):
|
||||
"""Exhaust the OrchestratorBlock.run async generator, collecting outputs."""
|
||||
outputs = []
|
||||
async for item in block.run(input_data, **kwargs):
|
||||
outputs.append(item)
|
||||
return outputs
|
||||
|
||||
|
||||
class TestExtendedThinkingValidationRaisesInBlock:
|
||||
"""Call OrchestratorBlock.run far enough to trigger the ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_anthropic_provider_raises_valueerror(self):
|
||||
"""EXTENDED_THINKING + openai provider raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
input_data = _make_input_data(model=LlmModel.GPT4O)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="Anthropic-compatible"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_claude_model_with_anthropic_provider_raises(self):
|
||||
"""A model with anthropic provider but non-claude name raises ValueError."""
|
||||
block = OrchestratorBlock()
|
||||
fake_model = _make_model_stub("anthropic", "not-a-claude-model")
|
||||
input_data = _make_input_data(model=fake_model)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
block,
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
pytest.raises(ValueError, match="only supports Claude models"),
|
||||
):
|
||||
await _collect_run_outputs(
|
||||
block,
|
||||
input_data,
|
||||
credentials=MagicMock(),
|
||||
graph_id="g",
|
||||
node_id="n",
|
||||
graph_exec_id="ge",
|
||||
node_exec_id="ne",
|
||||
user_id="u",
|
||||
graph_version=1,
|
||||
execution_context=MagicMock(),
|
||||
execution_processor=MagicMock(),
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for SmartDecisionMakerBlock compatibility with the OpenAI Responses API.
|
||||
"""Tests for OrchestratorBlock compatibility with the OpenAI Responses API.
|
||||
|
||||
The SmartDecisionMakerBlock manages conversation history in the Chat Completions
|
||||
The OrchestratorBlock manages conversation history in the Chat Completions
|
||||
format, but OpenAI models now use the Responses API which has a fundamentally
|
||||
different conversation structure. These tests document:
|
||||
|
||||
@@ -27,8 +27,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.smart_decision_maker import (
|
||||
SmartDecisionMakerBlock,
|
||||
from backend.blocks.orchestrator import (
|
||||
OrchestratorBlock,
|
||||
_combine_tool_responses,
|
||||
_convert_raw_response_to_dict,
|
||||
_create_tool_response,
|
||||
@@ -733,7 +733,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_no_reasoning_no_tools(self):
|
||||
"""Dict raw_response, no reasoning → appends assistant dict."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": "hi"})
|
||||
block._update_conversation(prompt, resp)
|
||||
@@ -741,7 +741,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_with_reasoning_no_tool_calls(self):
|
||||
"""Reasoning present, no tool calls → reasoning prepended."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response(
|
||||
{"role": "assistant", "content": "answer"},
|
||||
@@ -757,7 +757,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_dict_raw_response_with_reasoning_and_anthropic_tool_calls(self):
|
||||
"""Reasoning + Anthropic tool_use in content → reasoning skipped."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
raw = {
|
||||
"role": "assistant",
|
||||
@@ -772,7 +772,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_with_tool_outputs(self):
|
||||
"""Tool outputs → extended onto prompt."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": None})
|
||||
outputs = [{"role": "tool", "tool_call_id": "call_1", "content": "r"}]
|
||||
@@ -782,7 +782,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_without_tool_outputs(self):
|
||||
"""No tool outputs → only assistant message appended."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response({"role": "assistant", "content": "done"})
|
||||
block._update_conversation(prompt, resp, None)
|
||||
@@ -790,7 +790,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_string_raw_response(self):
|
||||
"""Ollama string → wrapped as assistant dict."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response("hello from ollama")
|
||||
block._update_conversation(prompt, resp)
|
||||
@@ -800,7 +800,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_responses_api_text_response_produces_valid_items(self):
|
||||
"""Responses API text response → conversation items must have valid role."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "user"},
|
||||
@@ -820,7 +820,7 @@ class TestUpdateConversation:
|
||||
|
||||
def test_responses_api_function_call_produces_valid_items(self):
|
||||
"""Responses API function_call → conversation items must have valid type."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
prompt: list[dict] = []
|
||||
resp = self._make_response(
|
||||
_MockResponse(output=[_MockFunctionCall("tool", "{}", call_id="call_1")])
|
||||
@@ -856,7 +856,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
# First response: tool call
|
||||
mock_tc = MagicMock()
|
||||
@@ -936,7 +936,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
"backend.blocks.orchestrator.get_database_manager_async_client",
|
||||
return_value=mock_db,
|
||||
), patch(
|
||||
"backend.executor.manager.async_update_node_execution_status",
|
||||
@@ -945,7 +945,7 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
"backend.integrations.creds_manager.IntegrationCredentialsManager"
|
||||
):
|
||||
|
||||
inp = SmartDecisionMakerBlock.Input(
|
||||
inp = OrchestratorBlock.Input(
|
||||
prompt="Improve this",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
@@ -992,7 +992,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
|
||||
"""Traditional mode: the yielded conversation must contain only valid items."""
|
||||
import backend.blocks.llm as llm_module
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
block = OrchestratorBlock()
|
||||
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.function.name = "my_tool"
|
||||
@@ -1028,7 +1028,7 @@ async def test_traditional_mode_conversation_valid_for_responses_api():
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock, return_value=resp
|
||||
), patch.object(block, "_create_tool_node_signatures", return_value=tool_sigs):
|
||||
|
||||
inp = SmartDecisionMakerBlock.Input(
|
||||
inp = OrchestratorBlock.Input(
|
||||
prompt="Do it",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,7 @@ class XMLParserBlock(Block):
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
raise ValueError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
@@ -53,7 +53,7 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
raise ValueError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
@@ -76,4 +76,7 @@ class XMLParserBlock(Block):
|
||||
except ValueError as val_e:
|
||||
raise ValueError(f"Validation error for dict:{val_e}") from val_e
|
||||
except SyntaxError as syn_e:
|
||||
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
# Raise as ValueError so the base Block.execute() wraps it as
|
||||
# BlockExecutionError (expected user-caused failure) instead of
|
||||
# BlockUnknownError (unexpected platform error that alerts Sentry).
|
||||
raise ValueError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -31,7 +31,7 @@ async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await create_chat_session(test_user_id, dry_run=False)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
|
||||
@@ -0,0 +1,633 @@
|
||||
"""Unit tests for baseline service pure-logic helpers.
|
||||
|
||||
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_conversation_updater,
|
||||
_BaselineStreamState,
|
||||
_compress_session_messages,
|
||||
_ThinkingStripper,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.prompt import CompressResult
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
class TestBaselineStreamState:
|
||||
def test_defaults(self):
|
||||
state = _BaselineStreamState()
|
||||
assert state.pending_events == []
|
||||
assert state.assistant_text == ""
|
||||
assert state.text_started is False
|
||||
assert state.turn_prompt_tokens == 0
|
||||
assert state.turn_completion_tokens == 0
|
||||
assert state.text_block_id # Should be a UUID string
|
||||
|
||||
def test_mutable_fields(self):
|
||||
state = _BaselineStreamState()
|
||||
state.assistant_text = "hello"
|
||||
state.turn_prompt_tokens = 100
|
||||
state.turn_completion_tokens = 50
|
||||
assert state.assistant_text == "hello"
|
||||
assert state.turn_prompt_tokens == 100
|
||||
assert state.turn_completion_tokens == 50
|
||||
|
||||
|
||||
class TestBaselineConversationUpdater:
|
||||
"""Tests for _baseline_conversation_updater which updates the OpenAI
|
||||
message list and transcript builder after each LLM call."""
|
||||
|
||||
def _make_transcript_builder(self) -> TranscriptBuilder:
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("test question")
|
||||
return builder
|
||||
|
||||
def test_text_only_response(self):
|
||||
"""When the LLM returns text without tool calls, the updater appends
|
||||
a single assistant message and records it in the transcript."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Hello, world!",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Hello, world!"
|
||||
# Transcript should have user + assistant
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
def test_tool_calls_response(self):
|
||||
"""When the LLM returns tool calls, the updater appends the assistant
|
||||
message with tool_calls and tool result messages."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text="Let me search...",
|
||||
tool_calls=[
|
||||
LLMToolCall(
|
||||
id="tc_1",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(
|
||||
tool_call_id="tc_1",
|
||||
tool_name="search",
|
||||
content="Found result",
|
||||
),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Messages: assistant (with tool_calls) + tool result
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "assistant"
|
||||
assert messages[0]["content"] == "Let me search..."
|
||||
assert len(messages[0]["tool_calls"]) == 1
|
||||
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
|
||||
assert messages[1]["role"] == "tool"
|
||||
assert messages[1]["tool_call_id"] == "tc_1"
|
||||
assert messages[1]["content"] == "Found result"
|
||||
|
||||
# Transcript: user + assistant(tool_use) + user(tool_result)
|
||||
assert builder.entry_count == 3
|
||||
|
||||
def test_tool_calls_without_text(self):
|
||||
"""Tool calls without accompanying text should still work."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="run", arguments="{}"),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "content" not in messages[0] # No text content
|
||||
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
|
||||
|
||||
def test_no_text_no_tools(self):
|
||||
"""When the response has no text and no tool calls, nothing is appended."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert len(messages) == 0
|
||||
# Only the user entry from setup
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_multiple_tool_calls(self):
|
||||
"""Multiple tool calls in a single response are all recorded."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
|
||||
LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
|
||||
ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# 1 assistant + 2 tool results
|
||||
assert len(messages) == 3
|
||||
assert len(messages[0]["tool_calls"]) == 2
|
||||
assert messages[1]["tool_call_id"] == "tc_1"
|
||||
assert messages[2]["tool_call_id"] == "tc_2"
|
||||
|
||||
def test_invalid_tool_arguments_handled(self):
|
||||
"""Tool call with invalid JSON arguments: the arguments field is
|
||||
stored as-is in the message, and orjson failure falls back to {}
|
||||
in the transcript content_blocks."""
|
||||
messages: list = []
|
||||
builder = self._make_transcript_builder()
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"),
|
||||
],
|
||||
raw_response=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
|
||||
]
|
||||
|
||||
_baseline_conversation_updater(
|
||||
messages,
|
||||
response,
|
||||
tool_results=tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# Should not raise — invalid JSON falls back to {} in transcript
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
|
||||
|
||||
|
||||
class TestCompressSessionMessagesPreservesToolCalls:
|
||||
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
|
||||
|
||||
Compression serialises ChatMessage to dict for ``compress_context`` and
|
||||
reifies the result back to ChatMessage. A regression that drops
|
||||
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
|
||||
list and break downstream tool-execution rounds.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compressed_output_keeps_tool_calls_and_ids(self):
|
||||
# Simulate compression that returns a summary + the most recent
|
||||
# assistant(tool_call) + tool(tool_result) intact.
|
||||
summary = {"role": "system", "content": "prior turns: user asked X"}
|
||||
assistant_with_tc = {
|
||||
"role": "assistant",
|
||||
"content": "calling tool",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": '{"q":"y"}'},
|
||||
}
|
||||
],
|
||||
}
|
||||
tool_result = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "tc_abc",
|
||||
"content": "search result",
|
||||
}
|
||||
|
||||
compress_result = CompressResult(
|
||||
messages=[summary, assistant_with_tc, tool_result],
|
||||
token_count=100,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
# Input: messages that should be compressed.
|
||||
input_messages = [
|
||||
ChatMessage(role="user", content="q1"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"q":"y"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
tool_call_id="tc_abc",
|
||||
content="search result",
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=compress_result),
|
||||
):
|
||||
compressed = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
# Summary, assistant(tool_calls), tool(tool_call_id).
|
||||
assert len(compressed) == 3
|
||||
# Assistant message must keep its tool_calls intact.
|
||||
assistant_msg = compressed[1]
|
||||
assert assistant_msg.role == "assistant"
|
||||
assert assistant_msg.tool_calls is not None
|
||||
assert len(assistant_msg.tool_calls) == 1
|
||||
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
|
||||
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
|
||||
# Tool-role message must keep tool_call_id for OpenAI linkage.
|
||||
tool_msg = compressed[2]
|
||||
assert tool_msg.role == "tool"
|
||||
assert tool_msg.tool_call_id == "tc_abc"
|
||||
assert tool_msg.content == "search result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uncompressed_passthrough_keeps_fields(self):
|
||||
"""When compression is a no-op (was_compacted=False), the original
|
||||
messages must be returned unchanged — including tool_calls."""
|
||||
input_messages = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="c",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "function",
|
||||
"function": {"name": "f", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
|
||||
]
|
||||
|
||||
noop_result = CompressResult(
|
||||
messages=[], # ignored when was_compacted=False
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=noop_result),
|
||||
):
|
||||
out = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
assert out is input_messages # same list returned
|
||||
assert out[0].tool_calls is not None
|
||||
assert out[0].tool_calls[0]["id"] == "t1"
|
||||
assert out[1].tool_call_id == "t1"
|
||||
|
||||
|
||||
# ---- _ThinkingStripper tests ---- #
|
||||
|
||||
|
||||
def test_thinking_stripper_basic_thinking_tag() -> None:
|
||||
"""<thinking>...</thinking> blocks are fully stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert s.process("<thinking>internal reasoning here</thinking>Hello!") == "Hello!"
|
||||
|
||||
|
||||
def test_thinking_stripper_internal_reasoning_tag() -> None:
|
||||
"""<internal_reasoning>...</internal_reasoning> blocks (Gemini) are stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert (
|
||||
s.process("<internal_reasoning>step by step</internal_reasoning>Answer")
|
||||
== "Answer"
|
||||
)
|
||||
|
||||
|
||||
def test_thinking_stripper_split_across_chunks() -> None:
|
||||
"""Tags split across multiple chunks are handled correctly."""
|
||||
s = _ThinkingStripper()
|
||||
out = s.process("Hello <thin")
|
||||
out += s.process("king>secret</thinking> world")
|
||||
assert out == "Hello world"
|
||||
|
||||
|
||||
def test_thinking_stripper_plain_text_preserved() -> None:
|
||||
"""Plain text with the word 'thinking' is not stripped."""
|
||||
s = _ThinkingStripper()
|
||||
assert (
|
||||
s.process("I am thinking about this problem")
|
||||
== "I am thinking about this problem"
|
||||
)
|
||||
|
||||
|
||||
def test_thinking_stripper_multiple_blocks() -> None:
|
||||
"""Multiple reasoning blocks in one stream are all stripped."""
|
||||
s = _ThinkingStripper()
|
||||
result = s.process(
|
||||
"A<thinking>x</thinking>B<internal_reasoning>y</internal_reasoning>C"
|
||||
)
|
||||
assert result == "ABC"
|
||||
|
||||
|
||||
def test_thinking_stripper_flush_discards_unclosed() -> None:
|
||||
"""Unclosed reasoning block is discarded on flush."""
|
||||
s = _ThinkingStripper()
|
||||
s.process("Start<thinking>never closed")
|
||||
flushed = s.flush()
|
||||
assert "never closed" not in flushed
|
||||
|
||||
|
||||
def test_thinking_stripper_empty_block() -> None:
|
||||
"""Empty reasoning blocks are handled gracefully."""
|
||||
s = _ThinkingStripper()
|
||||
assert s.process("Before<thinking></thinking>After") == "BeforeAfter"
|
||||
|
||||
|
||||
# ---- _filter_tools_by_permissions tests ---- #
|
||||
|
||||
|
||||
def _make_tool(name: str) -> ChatCompletionToolParam:
|
||||
"""Build a minimal OpenAI ChatCompletionToolParam."""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={"name": name, "parameters": {}},
|
||||
)
|
||||
|
||||
|
||||
class TestFilterToolsByPermissions:
|
||||
"""Tests for _filter_tools_by_permissions."""
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_empty_permissions_returns_all(self, _mock_names):
|
||||
"""Empty permissions (no filtering) returns every tool unchanged."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [_make_tool("run_block"), _make_tool("web_fetch")]
|
||||
perms = CopilotPermissions()
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
assert result == tools
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_allowlist_keeps_only_matching(self, _mock_names):
|
||||
"""Explicit allowlist (tools_exclude=False) keeps only listed tools."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [
|
||||
_make_tool("run_block"),
|
||||
_make_tool("web_fetch"),
|
||||
_make_tool("bash_exec"),
|
||||
]
|
||||
perms = CopilotPermissions(tools=["web_fetch"], tools_exclude=False)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
assert len(result) == 1
|
||||
assert result[0]["function"]["name"] == "web_fetch"
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_blacklist_excludes_listed(self, _mock_names):
|
||||
"""Blacklist (tools_exclude=True) removes only the listed tools."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [
|
||||
_make_tool("run_block"),
|
||||
_make_tool("web_fetch"),
|
||||
_make_tool("bash_exec"),
|
||||
]
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
names = [t["function"]["name"] for t in result]
|
||||
assert "bash_exec" not in names
|
||||
assert "run_block" in names
|
||||
assert "web_fetch" in names
|
||||
assert len(result) == 2
|
||||
|
||||
@patch(
|
||||
"backend.copilot.permissions.all_known_tool_names",
|
||||
return_value=frozenset({"run_block", "web_fetch", "bash_exec"}),
|
||||
)
|
||||
def test_unknown_tool_name_filtered_out(self, _mock_names):
|
||||
"""A tool whose name is not in all_known_tool_names is dropped."""
|
||||
from backend.copilot.baseline.service import _filter_tools_by_permissions
|
||||
from backend.copilot.permissions import CopilotPermissions
|
||||
|
||||
tools = [_make_tool("run_block"), _make_tool("unknown_tool")]
|
||||
perms = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
result = _filter_tools_by_permissions(tools, perms)
|
||||
names = [t["function"]["name"] for t in result]
|
||||
assert "unknown_tool" not in names
|
||||
assert names == ["run_block"]
|
||||
|
||||
|
||||
# ---- _prepare_baseline_attachments tests ---- #
|
||||
|
||||
|
||||
class TestPrepareBaselineAttachments:
|
||||
"""Tests for _prepare_baseline_attachments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file_ids(self):
|
||||
"""Empty file_ids returns empty hint and blocks."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
hint, blocks = await _prepare_baseline_attachments([], "user1", "sess1", "/tmp")
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_user_id(self):
|
||||
"""Empty user_id returns empty hint and blocks."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["file1"], "", "sess1", "/tmp"
|
||||
)
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_file_returns_vision_blocks(self):
|
||||
"""A PNG image within size limits is returned as a base64 vision block."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_info = AsyncMock()
|
||||
fake_info.name = "photo.png"
|
||||
fake_info.mime_type = "image/png"
|
||||
fake_info.size_bytes = 1024
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
|
||||
fake_manager.read_file_by_id = AsyncMock(return_value=b"\x89PNG_FAKE_DATA")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", "/tmp/workdir"
|
||||
)
|
||||
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0]["type"] == "image"
|
||||
assert blocks[0]["source"]["media_type"] == "image/png"
|
||||
assert blocks[0]["source"]["type"] == "base64"
|
||||
assert "photo.png" in hint
|
||||
assert "embedded as image" in hint
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_file_saved_to_working_dir(self, tmp_path):
|
||||
"""A non-image file is written to working_dir."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_info = AsyncMock()
|
||||
fake_info.name = "data.csv"
|
||||
fake_info.mime_type = "text/csv"
|
||||
fake_info.size_bytes = 42
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=fake_info)
|
||||
fake_manager.read_file_by_id = AsyncMock(return_value=b"col1,col2\na,b")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", str(tmp_path)
|
||||
)
|
||||
|
||||
assert blocks == []
|
||||
assert "data.csv" in hint
|
||||
assert "saved to" in hint
|
||||
saved = tmp_path / "data.csv"
|
||||
assert saved.exists()
|
||||
assert saved.read_bytes() == b"col1,col2\na,b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found_skipped(self):
|
||||
"""When get_file_info returns None the file is silently skipped."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
fake_manager = AsyncMock()
|
||||
fake_manager.get_file_info = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(return_value=fake_manager),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["missing_id"], "user1", "sess1", "/tmp"
|
||||
)
|
||||
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_manager_error(self):
|
||||
"""When get_workspace_manager raises, returns empty results."""
|
||||
from backend.copilot.baseline.service import _prepare_baseline_attachments
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.get_workspace_manager",
|
||||
new=AsyncMock(side_effect=RuntimeError("connection failed")),
|
||||
):
|
||||
hint, blocks = await _prepare_baseline_attachments(
|
||||
["fid1"], "user1", "sess1", "/tmp"
|
||||
)
|
||||
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
@@ -0,0 +1,667 @@
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
|
||||
import json as stdlib_json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
def _make_transcript_content(*roles: str) -> str:
|
||||
"""Build a minimal valid JSONL transcript from role names."""
|
||||
lines = []
|
||||
parent = ""
|
||||
for i, role in enumerate(roles):
|
||||
uid = f"uuid-{i}"
|
||||
entry: dict = {
|
||||
"type": role,
|
||||
"uuid": uid,
|
||||
"parentUuid": parent,
|
||||
"message": {
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": f"{role} message {i}"}],
|
||||
},
|
||||
}
|
||||
if role == "assistant":
|
||||
entry["message"]["id"] = f"msg_{i}"
|
||||
entry["message"]["model"] = "test-model"
|
||||
entry["message"]["type"] = "message"
|
||||
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
|
||||
lines.append(stdlib_json.dumps(entry))
|
||||
parent = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
|
||||
def test_fast_mode_selects_fast_model(self):
|
||||
assert _resolve_baseline_model("fast") == config.fast_model
|
||||
|
||||
def test_extended_thinking_selects_default_model(self):
|
||||
assert _resolve_baseline_model("extended_thinking") == config.model
|
||||
|
||||
def test_none_mode_selects_default_model(self):
|
||||
"""Critical: baseline users without a mode MUST keep the default (opus)."""
|
||||
assert _resolve_baseline_model(None) == config.model
|
||||
|
||||
def test_default_and_fast_models_differ(self):
|
||||
"""Sanity: the two tiers are actually distinct in production config."""
|
||||
assert config.model != config.fast_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content='{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_exception_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=0,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
|
||||
class TestUploadFinalTranscript:
|
||||
"""``_upload_final_transcript`` serialises and calls storage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uploads_valid_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
call_kwargs = upload_mock.await_args.kwargs
|
||||
assert call_kwargs["user_id"] == "user-1"
|
||||
assert call_kwargs["session_id"] == "session-1"
|
||||
assert call_kwargs["message_count"] == 2
|
||||
assert "hello" in call_kwargs["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_upload_when_builder_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=0,
|
||||
)
|
||||
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_upload_exceptions(self):
|
||||
"""Upload failures should not propagate (flow continues for the user)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
|
||||
):
|
||||
# Should not raise.
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
|
||||
class TestRecordTurnToTranscript:
|
||||
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
|
||||
|
||||
def test_records_final_assistant_text(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text="hello there",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "hello there" in jsonl
|
||||
assert STOP_REASON_END_TURN in jsonl
|
||||
|
||||
def test_records_tool_use_then_tool_result(self):
|
||||
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="use a tool")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
|
||||
],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# user, assistant(tool_use), user(tool_result) = 3 entries
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert STOP_REASON_TOOL_USE in jsonl
|
||||
assert "tool_use" in jsonl
|
||||
assert "tool_result" in jsonl
|
||||
assert "call-1" in jsonl
|
||||
|
||||
def test_records_nothing_on_empty_response(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_malformed_tool_args_dont_crash(self):
|
||||
"""Bad JSON in tool arguments falls back to {} without raising."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert '"input":{}' in jsonl
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
"""End-to-end: load prior → append new turn → upload."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# New user turn.
|
||||
builder.append_user(content="new question")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# New assistant turn.
|
||||
response = LLMLoopResponse(
|
||||
response_text="new answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
# Upload.
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "new question" in uploaded
|
||||
assert "new answer" in uploaded
|
||||
# Original content preserved in the round trip.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_append_guard(self):
|
||||
"""Backfill only runs when the last entry is not already assistant."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
# Simulate the backfill guard from stream_chat_completion_baseline.
|
||||
assistant_text = "partial text before error"
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": assistant_text}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert builder.last_entry_type == "assistant"
|
||||
assert "partial text before error" in builder.to_jsonl()
|
||||
|
||||
# Second invocation: the guard must prevent double-append.
|
||||
initial_count = builder.entry_count
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "duplicate"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
|
||||
"""``is_transcript_stale`` gates prior-transcript loading."""
|
||||
|
||||
def test_none_download_is_not_stale(self):
|
||||
assert is_transcript_stale(None, session_msg_count=5) is False
|
||||
|
||||
def test_zero_message_count_is_not_stale(self):
|
||||
"""Legacy transcripts without msg_count tracking must remain usable."""
|
||||
dl = TranscriptDownload(content="", message_count=0)
|
||||
assert is_transcript_stale(dl, session_msg_count=20) is False
|
||||
|
||||
def test_stale_when_covers_less_than_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=2)
|
||||
# session has 6 messages; transcript must cover at least 5 (6-1).
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is True
|
||||
|
||||
def test_fresh_when_covers_full_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_fresh_when_exceeds_prefix(self):
|
||||
"""Race: transcript ahead of session count is still acceptable."""
|
||||
dl = TranscriptDownload(content="", message_count=10)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_boundary_equal_to_prefix_minus_one(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
|
||||
"""``should_upload_transcript`` gates the final upload."""
|
||||
|
||||
def test_upload_allowed_for_user_with_coverage(self):
|
||||
assert should_upload_transcript("user-1", True) is True
|
||||
|
||||
def test_upload_skipped_when_no_user(self):
|
||||
assert should_upload_transcript(None, True) is False
|
||||
|
||||
def test_upload_skipped_when_empty_user(self):
|
||||
assert should_upload_transcript("", True) is False
|
||||
|
||||
def test_upload_skipped_without_coverage(self):
|
||||
"""Partial transcript must never clobber a more complete stored one."""
|
||||
assert should_upload_transcript("user-1", False) is False
|
||||
|
||||
def test_upload_skipped_when_no_user_and_no_coverage(self):
|
||||
assert should_upload_transcript(None, False) is False
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
|
||||
"""End-to-end: download → validate → build → upload.
|
||||
|
||||
Simulates the full transcript lifecycle inside
|
||||
``stream_chat_completion_baseline`` by mocking the storage layer and
|
||||
driving each step through the real helpers.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_happy_path(self):
|
||||
"""Fresh download, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
# --- 1. Download & load prior transcript ---
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
|
||||
# --- 2. Append a new user turn + a new assistant response ---
|
||||
builder.append_user(content="follow-up question")
|
||||
_record_turn_to_transcript(
|
||||
LLMLoopResponse(
|
||||
response_text="follow-up answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
),
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is True
|
||||
)
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "follow-up question" in uploaded
|
||||
assert "follow-up answer" in uploaded
|
||||
# Original prior-turn content preserved.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale download → covers=False → upload must be skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored transcript only covers 2 → stale.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=2,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=stale),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
"""Anonymous (user_id=None) → upload gate must return False."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior transcript → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No download: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete transcript with a single-turn snapshot.
|
||||
assert covers is False
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
@@ -8,18 +8,35 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
# Per-request routing mode for a single chat turn.
|
||||
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
|
||||
# - 'extended_thinking': route to the Claude Agent SDK path with the default
|
||||
# (opus) model.
|
||||
# ``None`` means "no override"; the server falls back to the Claude Code
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||
default="anthropic/claude-opus-4.6",
|
||||
description="Default model for extended thinking mode",
|
||||
)
|
||||
fast_model: str = Field(
|
||||
default="anthropic/claude-sonnet-4",
|
||||
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
simulation_model: str = Field(
|
||||
default="google/gemini-2.5-flash",
|
||||
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default=OPENROUTER_BASE_URL,
|
||||
@@ -77,11 +94,11 @@ class ChatConfig(BaseSettings):
|
||||
# allows ~70-100 turns/day.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# TODO: These are deploy-time constants applied identically to every user.
|
||||
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
|
||||
# must move to the database (e.g., a UserPlan table) and get_usage_status /
|
||||
# check_rate_limit would look up each user's specific limits instead of
|
||||
# reading config.daily_token_limit / config.weekly_token_limit.
|
||||
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
|
||||
# ENTERPRISE) multiply these by their tier multiplier (see
|
||||
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
|
||||
# User.subscriptionTier DB column and resolved inside
|
||||
# get_global_rate_limits().
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
@@ -91,6 +108,20 @@ class ChatConfig(BaseSettings):
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Cost (in credits / cents) to reset the daily rate limit using credits.
|
||||
# When a user hits their daily limit, they can spend this amount to reset
|
||||
# the daily counter and keep working. Set to 0 to disable the feature.
|
||||
rate_limit_reset_cost: int = Field(
|
||||
default=500,
|
||||
ge=0,
|
||||
description="Credit cost (in cents) for resetting the daily rate limit. 0 = disabled.",
|
||||
)
|
||||
max_daily_resets: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
description="Maximum number of credit-based rate limit resets per user per day. 0 = unlimited.",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
present — mirrors the fallback logic in ``build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user