mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
143 Commits
feat/trans
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95c6907ccd | ||
|
|
f4bc3c2012 | ||
|
|
f265ef8ac3 | ||
|
|
c79e6ff30a | ||
|
|
7db8bf161a | ||
|
|
84650d0f4d | ||
|
|
0467cb2e49 | ||
|
|
24d0c35ed3 | ||
|
|
8aae7751dc | ||
|
|
725da7e887 | ||
|
|
bd9e9ec614 | ||
|
|
88589764b5 | ||
|
|
c659f3b058 | ||
|
|
80581a8364 | ||
|
|
3c046eb291 | ||
|
|
3e25488b2d | ||
|
|
57b17dc8e1 | ||
|
|
a20188ae59 | ||
|
|
c410be890e | ||
|
|
37d9863552 | ||
|
|
2f42ff9b47 | ||
|
|
914efc53e5 | ||
|
|
17e78ca382 | ||
|
|
7ba05366ed | ||
|
|
ca74f980c1 | ||
|
|
68f5d2ad08 | ||
|
|
2b3d730ca9 | ||
|
|
f28628e34b | ||
|
|
b6a027fd2b | ||
|
|
fb74fcf4a4 | ||
|
|
28b26dde94 | ||
|
|
d677978c90 | ||
|
|
a347c274b7 | ||
|
|
f79d8f0449 | ||
|
|
1bc48c55d5 | ||
|
|
9d0a31c0f1 | ||
|
|
9b086e39c6 | ||
|
|
5867e4d613 | ||
|
|
85f0d8353a | ||
|
|
f871717f68 | ||
|
|
f08e52dc86 | ||
|
|
500b345b3b | ||
|
|
995dd1b5f3 | ||
|
|
336114f217 | ||
|
|
866563ad25 | ||
|
|
e79928a815 | ||
|
|
1771ed3bef | ||
|
|
550fa5a319 | ||
|
|
8528dffbf2 | ||
|
|
8fbf6a4b09 | ||
|
|
239148596c | ||
|
|
a880d73481 | ||
|
|
80bfd64ffa | ||
|
|
0076ad2a1a | ||
|
|
edb3d322f0 | ||
|
|
9381057079 | ||
|
|
f21a36ca37 | ||
|
|
ee5382a064 | ||
|
|
b80e5ea987 | ||
|
|
3d4fcfacb6 | ||
|
|
32eac6d52e | ||
|
|
9762f4cde7 | ||
|
|
76901ba22f | ||
|
|
23b65939f3 | ||
|
|
1c27eaac53 | ||
|
|
923b164794 | ||
|
|
e86ac21c43 | ||
|
|
94224be841 | ||
|
|
da4bdc7ab9 | ||
|
|
7176cecf25 | ||
|
|
f35210761c | ||
|
|
1ebcf85669 | ||
|
|
ab7c38bda7 | ||
|
|
0f67e45d05 | ||
|
|
b9ce37600e | ||
|
|
3921deaef1 | ||
|
|
f01f668674 | ||
|
|
f7a3491f91 | ||
|
|
cbff3b53d3 | ||
|
|
5b9a4c52c9 | ||
|
|
0ce1c90b55 | ||
|
|
d4c6eb9adc | ||
|
|
1bb91b53b7 | ||
|
|
a5f9c43a41 | ||
|
|
1240f38f75 | ||
|
|
f617f50f0b | ||
|
|
943a1df815 | ||
|
|
593001e0c8 | ||
|
|
e1db8234a3 | ||
|
|
282173be9d | ||
|
|
5d9a169e04 | ||
|
|
6fd1050457 | ||
|
|
02708bcd00 | ||
|
|
156d61fe5c | ||
|
|
5a29de0e0e | ||
|
|
e657472162 | ||
|
|
4d00e0f179 | ||
|
|
1d7282b5f3 | ||
|
|
e3591fcaa3 | ||
|
|
876dc32e17 | ||
|
|
616e29f5e4 | ||
|
|
280a98ad38 | ||
|
|
c7f2a7dd03 | ||
|
|
6d0e2063ec | ||
|
|
8b577ae194 | ||
|
|
d8f5f783ae | ||
|
|
82d22f3680 | ||
|
|
50622333d1 | ||
|
|
27af5782a9 | ||
|
|
522f932e67 | ||
|
|
a6124b06d5 | ||
|
|
ae660ea04f | ||
|
|
2479f3a1c4 | ||
|
|
8153306384 | ||
|
|
9c3d100a22 | ||
|
|
fc3bf6c154 | ||
|
|
e32d258a7e | ||
|
|
3e86544bfe | ||
|
|
c6b729bdfa | ||
|
|
7a391fbd99 | ||
|
|
791dd7cb48 | ||
|
|
f0800b9420 | ||
|
|
60bc49ba50 | ||
|
|
ba4f4b6242 | ||
|
|
8892bcd230 | ||
|
|
48ff8300a4 | ||
|
|
c268fc6464 | ||
|
|
aff3fb44af | ||
|
|
9a41312769 | ||
|
|
048fb06b0a | ||
|
|
3f653e6614 | ||
|
|
c9c3d54b2b | ||
|
|
53d58e21d3 | ||
|
|
fa04fb41d8 | ||
|
|
d9c16ded65 | ||
|
|
6dc8429ae7 | ||
|
|
cfe22e5a8f | ||
|
|
a8259ca935 | ||
|
|
1f1288d623 | ||
|
|
02645732b8 | ||
|
|
ba301a3912 | ||
|
|
0cd9c0d87a | ||
|
|
a083493aa2 |
1
.agents/skills
Symbolic link
1
.agents/skills
Symbolic link
@@ -0,0 +1 @@
|
||||
../.claude/skills
|
||||
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
@@ -1,35 +0,0 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
@@ -1,16 +0,0 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
106
.claude/skills/open-pr/SKILL.md
Normal file
106
.claude/skills/open-pr/SKILL.md
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
name: open-pr
|
||||
description: Open a pull request with proper PR template, test coverage, and review workflow. Guides agents through creating a PR that follows repo conventions, ensures existing behaviors aren't broken, covers new behaviors with tests, and handles review via bot when local testing isn't possible. TRIGGER when user asks to "open a PR", "create a PR", "make a PR", "submit a PR", "open pull request", "push and create PR", or any variation of opening/submitting a pull request.
|
||||
user-invocable: true
|
||||
args: "[base-branch] — optional target branch (defaults to dev)."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Open a Pull Request
|
||||
|
||||
## Step 1: Pre-flight checks
|
||||
|
||||
Before opening the PR:
|
||||
|
||||
1. Ensure all changes are committed
|
||||
2. Ensure the branch is pushed to the remote (`git push -u origin <branch>`)
|
||||
3. Run linters/formatters across the whole repo (not just changed files) and commit any fixes
|
||||
|
||||
## Step 2: Test coverage
|
||||
|
||||
**This is critical.** Before opening the PR, verify:
|
||||
|
||||
### Existing behavior is not broken
|
||||
- Identify which modules/components your changes touch
|
||||
- Run the existing test suites for those areas
|
||||
- If tests fail, fix them before opening the PR — do not open a PR with known regressions
|
||||
|
||||
### New behavior has test coverage
|
||||
- Every new feature, endpoint, or behavior change needs tests
|
||||
- If you added a new block, add tests for that block
|
||||
- If you changed API behavior, add or update API tests
|
||||
- If you changed frontend behavior, verify it doesn't break existing flows
|
||||
|
||||
If you cannot run the full test suite locally, note which tests you ran and which you couldn't in the test plan.
|
||||
|
||||
## Step 3: Create the PR using the repo template
|
||||
|
||||
Read the canonical PR template at `.github/PULL_REQUEST_TEMPLATE.md` and use it **verbatim** as your PR body:
|
||||
|
||||
1. Read the template: `cat .github/PULL_REQUEST_TEMPLATE.md`
|
||||
2. Preserve the exact section titles and formatting, including:
|
||||
- `### Why / What / How`
|
||||
- `### Changes 🏗️`
|
||||
- `### Checklist 📋`
|
||||
3. Replace HTML comment prompts (`<!-- ... -->`) with actual content; do not leave them in
|
||||
4. **Do not pre-check boxes** — leave all checkboxes as `- [ ]` until each step is actually completed
|
||||
5. Do not alter the template structure, rename sections, or remove any checklist items
|
||||
|
||||
**PR title must use conventional commit format** (e.g., `feat(backend): add new block`, `fix(frontend): resolve routing bug`, `dx(skills): update PR workflow`). See CLAUDE.md for the full list of scopes.
|
||||
|
||||
Use `gh pr create` with the base branch (defaults to `dev` if no `[base-branch]` was provided). Use `--body-file` to avoid shell interpretation of backticks and special characters:
|
||||
|
||||
```bash
|
||||
BASE_BRANCH="${BASE_BRANCH:-dev}"
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
<filled-in template from .github/PULL_REQUEST_TEMPLATE.md>
|
||||
PREOF
|
||||
gh pr create --base "$BASE_BRANCH" --title "<type>(scope): short description" --body-file "$PR_BODY"
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
|
||||
## Step 4: Review workflow
|
||||
|
||||
### If you have a workspace that allows testing (docker, running backend, etc.)
|
||||
- Run `/pr-test` to do E2E manual testing of the PR using docker compose, agent-browser, and API calls. This is the most thorough way to validate your changes before review.
|
||||
- After testing, run `/pr-review` to self-review the PR for correctness, security, code quality, and testing gaps before requesting human review.
|
||||
|
||||
### If you do NOT have a workspace that allows testing
|
||||
This is common for agents running in worktrees without a full stack. In this case:
|
||||
|
||||
1. Run `/pr-review` locally to catch obvious issues before pushing
|
||||
2. **Comment `/review` on the PR** after creating it to trigger the review bot
|
||||
3. **Poll for the review** rather than blindly waiting — check for new review comments every 30 seconds using `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` and the GraphQL inline threads query. The bot typically responds within 30 minutes, but polling lets the agent react as soon as it arrives.
|
||||
4. Do NOT proceed or merge until the bot review comes back
|
||||
5. Address any issues the bot raises — use `/pr-address` which has a full polling loop with CI + comment tracking
|
||||
|
||||
```bash
|
||||
# After creating the PR:
|
||||
PR_NUMBER=$(gh pr view --json number -q .number)
|
||||
gh pr comment "$PR_NUMBER" --body "/review"
|
||||
# Then use /pr-address to poll for and address the review when it arrives
|
||||
```
|
||||
|
||||
## Step 5: Address review feedback
|
||||
|
||||
Once the review bot or human reviewers leave comments:
|
||||
- Run `/pr-address` to address review comments. It will loop until CI is green and all comments are resolved.
|
||||
- Do not merge without human approval.
|
||||
|
||||
## Related skills
|
||||
|
||||
| Skill | When to use |
|
||||
|---|---|
|
||||
| `/pr-test` | E2E testing with docker compose, agent-browser, API calls — use when you have a running workspace |
|
||||
| `/pr-review` | Review for correctness, security, code quality — use before requesting human review |
|
||||
| `/pr-address` | Address reviewer comments and loop until CI green — use after reviews come in |
|
||||
|
||||
## Step 6: Post-creation
|
||||
|
||||
After the PR is created and review is triggered:
|
||||
- Share the PR URL with the user
|
||||
- If waiting on the review bot, let the user know the expected wait time (~30 min)
|
||||
- Do not merge without human approval
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
210
.claude/skills/pr-address/SKILL.md
Normal file
210
.claude/skills/pr-address/SKILL.md
Normal file
@@ -0,0 +1,210 @@
|
||||
---
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Address
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the PR description
|
||||
|
||||
Understand the **Why / What / How** before addressing comments — you need context to make good fixes:
|
||||
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
### 1. Inline review threads — GraphQL (primary source of actionable items)
|
||||
|
||||
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
|
||||
|
||||
```bash
|
||||
gh api graphql -f query='
|
||||
{
|
||||
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
|
||||
pullRequest(number: {N}) {
|
||||
reviewThreads(first: 100) {
|
||||
pageInfo { hasNextPage endCursor }
|
||||
nodes {
|
||||
id
|
||||
isResolved
|
||||
path
|
||||
comments(last: 1) {
|
||||
nodes { databaseId body author { login } createdAt }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
|
||||
|
||||
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
|
||||
|
||||
### 2. Top-level reviews — REST (MUST paginate)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
|
||||
**CRITICAL — always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30 — including `autogpt-reviewer`'s structured review which is typically posted after several CI runs and sits well beyond the first page.
|
||||
|
||||
Two things to extract:
|
||||
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
|
||||
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
|
||||
|
||||
**Where each reviewer posts:**
|
||||
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
|
||||
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
|
||||
- Human reviewers — can post in any source. Address ALL non-empty feedback.
|
||||
|
||||
### 3. PR conversation comments — REST
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
|
||||
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
After fixing, format the changed code:
|
||||
|
||||
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||
|
||||
If API routes changed, regenerate the frontend client:
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing. Each fix should be visible on GitHub right away so CI can start and reviewers can see progress.
|
||||
|
||||
**Never push empty commits** (`git commit --allow-empty`) to re-trigger CI or bot checks. When a check fails, investigate the root cause (unchecked PR checklist, unaddressed review comments, code issues) and fix those directly. Empty commits add noise to git history.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ wait for CI (while addressing new comments) → fix failures → push
|
||||
→ re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
### Polling for CI + new comments
|
||||
|
||||
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
|
||||
|
||||
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
|
||||
|
||||
**Polling loop — repeat every 30 seconds:**
|
||||
|
||||
1. Check CI status:
|
||||
```bash
|
||||
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
|
||||
```
|
||||
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
|
||||
|
||||
2. Check for merge conflicts:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
|
||||
```
|
||||
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
|
||||
|
||||
3. Check for new/changed comments (all three sources):
|
||||
|
||||
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
|
||||
- A new thread `id` appears that wasn't in the baseline (new thread), OR
|
||||
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
|
||||
|
||||
**Conversation comments:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
|
||||
```
|
||||
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
|
||||
|
||||
**Top-level reviews:**
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
|
||||
```
|
||||
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
|
||||
|
||||
4. **React in this precedence order (first match wins):**
|
||||
|
||||
| What happened | Action |
|
||||
|---|---|
|
||||
| Merge conflict detected | See "Resolving merge conflicts" below. |
|
||||
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
|
||||
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
|
||||
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
|
||||
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
|
||||
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
|
||||
|
||||
### Resolving merge conflicts
|
||||
|
||||
1. Identify the PR's target branch and remote:
|
||||
```bash
|
||||
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
|
||||
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
|
||||
```
|
||||
|
||||
2. Pull the latest base branch with a 3-way merge:
|
||||
```bash
|
||||
git pull {base-remote} {base-branch} --no-rebase
|
||||
```
|
||||
|
||||
3. Resolve conflicting files, then verify no conflict markers remain:
|
||||
```bash
|
||||
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
|
||||
echo "Unresolved conflict markers found — resolve before proceeding."
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
4. Stage and push:
|
||||
```bash
|
||||
git add <conflicted-files>
|
||||
git commit -m "Resolve merge conflicts with {base-branch}"
|
||||
git push
|
||||
```
|
||||
|
||||
5. Restart the polling loop from the top — new commits reset CI status.
|
||||
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,51 +1,86 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review Comment Workflow
|
||||
# PR Review
|
||||
|
||||
## Steps
|
||||
## Find the PR
|
||||
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## CRITICAL: Do Not Stop
|
||||
## Read the PR description
|
||||
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
Before reading code, understand the **why**, **what**, and **how** from the PR description:
|
||||
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
```bash
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
```
|
||||
|
||||
## Rules
|
||||
Every PR should have a Why / What / How structure. If any of these are missing, note it as feedback.
|
||||
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
gh pr diff {N}
|
||||
```
|
||||
|
||||
## Fetch existing review comments
|
||||
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Description quality:** Does the PR description cover Why (motivation/problem), What (summary of changes), and How (approach/implementation details)? If any are missing, request them — you can't judge the approach without understanding the problem and intent.
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||
|
||||
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||
|
||||
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||
|
||||
## Output format
|
||||
|
||||
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||
|
||||
| Tier | Badge | Meaning |
|
||||
|---|---|---|
|
||||
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||
| Nit | `🔵 **Nit**` | Style / wording |
|
||||
|
||||
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||
|
||||
## Post inline comments
|
||||
|
||||
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||
|
||||
```bash
|
||||
# Get the latest commit SHA for the PR
|
||||
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||
|
||||
# Post an inline comment on a specific file/line
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||
-f commit_id="$COMMIT_SHA" \
|
||||
-f path="<file path>" \
|
||||
-F line=<line number>
|
||||
```
|
||||
|
||||
754
.claude/skills/pr-test/SKILL.md
Normal file
754
.claude/skills/pr-test/SKILL.md
Normal file
@@ -0,0 +1,754 @@
|
||||
---
|
||||
name: pr-test
|
||||
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
|
||||
user-invocable: true
|
||||
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "2.0.0"
|
||||
---
|
||||
|
||||
# Manual E2E Test
|
||||
|
||||
Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.
|
||||
|
||||
## Critical Requirements
|
||||
|
||||
These are NON-NEGOTIABLE. Every test run MUST satisfy ALL the following:
|
||||
|
||||
### 1. Screenshots at Every Step
|
||||
- Take a screenshot at EVERY significant test step — not just at the end
|
||||
- Every test scenario MUST have at least one BEFORE and one AFTER screenshot
|
||||
- Name screenshots sequentially: `{NN}-{action}-{state}.png` (e.g., `01-credits-before.png`, `02-credits-after.png`)
|
||||
- If a screenshot is missing for a scenario, the test is INCOMPLETE — go back and take it
|
||||
|
||||
### 2. Screenshots MUST Be Posted to PR
|
||||
- Push ALL screenshots to a temp branch `test-screenshots/pr-{N}`
|
||||
- Post a PR comment with ALL screenshots embedded inline using GitHub raw URLs
|
||||
- This is NOT optional — every test run MUST end with a PR comment containing screenshots
|
||||
- If screenshot upload fails, retry. If it still fails, list failed files and require manual drag-and-drop/paste attachment in the PR comment
|
||||
|
||||
### 3. State Verification with Before/After Evidence
|
||||
- For EVERY state-changing operation (API call, user action), capture the state BEFORE and AFTER
|
||||
- Log the actual API response values (e.g., `credits_before=100, credits_after=95`)
|
||||
- Screenshot MUST show the relevant UI state change
|
||||
- Compare expected vs actual values explicitly — do not just eyeball it
|
||||
|
||||
### 4. Negative Test Cases Are Mandatory
|
||||
- Test at least ONE negative case per feature (e.g., insufficient credits, invalid input, unauthorized access)
|
||||
- Verify error messages are user-friendly and accurate
|
||||
- Verify the system state did NOT change after a rejected operation
|
||||
|
||||
### 5. Test Report Must Include Full Evidence
|
||||
Each test scenario in the report MUST have:
|
||||
- **Steps**: What was done (exact commands or UI actions)
|
||||
- **Expected**: What should happen
|
||||
- **Actual**: What actually happened
|
||||
- **API Evidence**: Before/after API response values for state-changing operations
|
||||
- **Screenshot Evidence**: Before/after screenshots with explanations
|
||||
|
||||
## State Manipulation for Realistic Testing
|
||||
|
||||
When testing features that depend on specific states (rate limits, credits, quotas):
|
||||
|
||||
1. **Use Redis CLI to set counters directly:**
|
||||
```bash
|
||||
# Find the Redis container
|
||||
REDIS_CONTAINER=$(docker ps --format '{{.Names}}' | grep redis | head -1)
|
||||
# Set a key with expiry
|
||||
docker exec $REDIS_CONTAINER redis-cli SET key value EX ttl
|
||||
# Example: Set rate limit counter to near-limit
|
||||
docker exec $REDIS_CONTAINER redis-cli SET "rate_limit:user:test@test.com" 99 EX 3600
|
||||
# Example: Check current value
|
||||
docker exec $REDIS_CONTAINER redis-cli GET "rate_limit:user:test@test.com"
|
||||
```
|
||||
|
||||
2. **Use API calls to check before/after state:**
|
||||
```bash
|
||||
# BEFORE: Record current state
|
||||
BEFORE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits BEFORE: $BEFORE"
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER: Record new state and compare
|
||||
AFTER=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
echo "Credits AFTER: $AFTER"
|
||||
echo "Delta: $(( BEFORE - AFTER ))"
|
||||
```
|
||||
|
||||
3. **Take screenshots BEFORE and AFTER state changes** — the UI must reflect the backend state change
|
||||
|
||||
4. **Never rely on mocked/injected browser state** — always use real backend state. Do NOT use `agent-browser eval` to fake UI state. The backend must be the source of truth.
|
||||
|
||||
5. **Use direct DB queries when needed:**
|
||||
```bash
|
||||
# Query via Supabase's PostgREST or docker exec into the DB
|
||||
docker exec supabase-db psql -U supabase_admin -d postgres -c "SELECT credits FROM user_credits WHERE user_id = '...';"
|
||||
```
|
||||
|
||||
6. **After every API test, verify the state change actually persisted:**
|
||||
```bash
|
||||
# Example: After a credits purchase, verify DB matches API
|
||||
API_CREDITS=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/credits | jq '.credits')
|
||||
DB_CREDITS=$(docker exec supabase-db psql -U supabase_admin -d postgres -t -c "SELECT credits FROM user_credits WHERE user_id = '...';" | tr -d ' ')
|
||||
[ "$API_CREDITS" = "$DB_CREDITS" ] && echo "CONSISTENT" || echo "MISMATCH: API=$API_CREDITS DB=$DB_CREDITS"
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
|
||||
- If `--fix` flag is present, auto-fix bugs found and push fixes (like pr-address loop)
|
||||
|
||||
## Step 0: Resolve the target
|
||||
|
||||
```bash
|
||||
# If argument is a PR number, find its worktree
|
||||
gh pr view {N} --json headRefName --jq '.headRefName'
|
||||
# If argument is a path, use it directly
|
||||
```
|
||||
|
||||
Determine:
|
||||
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
|
||||
- `WORKTREE_PATH` — the worktree directory
|
||||
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
|
||||
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
|
||||
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
|
||||
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
|
||||
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
|
||||
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`
|
||||
|
||||
Create the results directory:
|
||||
```bash
|
||||
PR_NUMBER=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
|
||||
PR_TITLE=$(cd $WORKTREE_PATH && gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
|
||||
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
|
||||
mkdir -p $RESULTS_DIR
|
||||
```
|
||||
|
||||
**Test user credentials** (for logging into the UI or verifying results manually):
|
||||
- Email: `test@test.com`
|
||||
- Password: `testtest123`
|
||||
|
||||
## Step 1: Understand the PR
|
||||
|
||||
Before testing, understand what changed:
|
||||
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
|
||||
# Read PR description to understand the WHY
|
||||
gh pr view {N} --json body --jq '.body'
|
||||
|
||||
git log --oneline dev..HEAD | head -20
|
||||
git diff dev --stat
|
||||
```
|
||||
|
||||
Read the PR description (Why / What / How) and changed files to understand:
|
||||
0. **Why** does this PR exist? What problem does it solve?
|
||||
1. **What** feature/fix does this PR implement?
|
||||
2. **How** does it work? What's the approach?
|
||||
3. What components are affected? (backend, frontend, copilot, executor, etc.)
|
||||
4. What are the key user-facing behaviors to test?
|
||||
|
||||
## Step 2: Write test scenarios
|
||||
|
||||
Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:
|
||||
|
||||
```markdown
|
||||
# Test Plan: PR #{N} — {title}
|
||||
|
||||
## Scenarios
|
||||
1. [Scenario name] — [what to verify]
|
||||
2. ...
|
||||
|
||||
## API Tests (if applicable)
|
||||
1. [Endpoint] — [expected behavior]
|
||||
- Before state: [what to check before]
|
||||
- After state: [what to verify changed]
|
||||
|
||||
## UI Tests (if applicable)
|
||||
1. [Page/component] — [interaction to test]
|
||||
- Screenshot before: [what to capture]
|
||||
- Screenshot after: [what to capture]
|
||||
|
||||
## Negative Tests (REQUIRED — at least one per feature)
|
||||
1. [What should NOT happen] — [how to trigger it]
|
||||
- Expected error: [what error message/code]
|
||||
- State unchanged: [what to verify did NOT change]
|
||||
```
|
||||
|
||||
**Be critical** — include edge cases, error paths, and security checks. Every scenario MUST specify what screenshots to take and what state to verify.
|
||||
|
||||
## Step 3: Environment setup
|
||||
|
||||
### 3a. Copy .env files from the root worktree
|
||||
|
||||
The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:
|
||||
|
||||
```bash
|
||||
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
|
||||
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
|
||||
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3b. Configure copilot authentication
|
||||
|
||||
The copilot needs an LLM API to function. Two approaches (try subscription first):
|
||||
|
||||
#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)
|
||||
|
||||
The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.
|
||||
|
||||
Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):
|
||||
|
||||
```bash
|
||||
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
|
||||
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
**How it works:** The script reads the OAuth token from:
|
||||
- **macOS**: system keychain (`"Claude Code-credentials"`)
|
||||
- **Linux/WSL**: `~/.claude/.credentials.json`
|
||||
- **Windows**: `%APPDATA%/claude/.credentials.json`
|
||||
|
||||
It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.
|
||||
|
||||
**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`
|
||||
|
||||
#### Option 2: OpenRouter API key mode (fallback)
|
||||
|
||||
If subscription mode doesn't work, switch to API key mode using OpenRouter:
|
||||
|
||||
```bash
|
||||
# In $BACKEND_DIR/.env, ensure these are set:
|
||||
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
|
||||
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
|
||||
CHAT_BASE_URL=https://openrouter.ai/api/v1
|
||||
CHAT_USE_CLAUDE_AGENT_SDK=true
|
||||
```
|
||||
|
||||
Use `sed` to update these values:
|
||||
```bash
|
||||
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2)
|
||||
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
|
||||
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
|
||||
# Add or update CHAT_API_KEY and CHAT_BASE_URL
|
||||
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
|
||||
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
|
||||
```
|
||||
|
||||
### 3c. Stop conflicting containers
|
||||
|
||||
```bash
|
||||
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
|
||||
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
|
||||
docker stop "$name" 2>/dev/null
|
||||
done
|
||||
```
|
||||
|
||||
### 3e. Build and start
|
||||
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi
|
||||
|
||||
cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
|
||||
```
|
||||
|
||||
**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.
|
||||
|
||||
**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.
|
||||
|
||||
### 3f. Wait for services to be ready
|
||||
|
||||
```bash
|
||||
# Poll until backend and frontend respond
|
||||
for i in $(seq 1 60); do
|
||||
BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
|
||||
FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
|
||||
if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
|
||||
echo "Services ready"
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
```
|
||||
|
||||
|
||||
### 3h. Create test user and get auth token
|
||||
|
||||
```bash
|
||||
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')
|
||||
|
||||
# Signup (idempotent — returns "User already registered" if exists)
|
||||
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}')
|
||||
|
||||
# If "Database error finding user", restart supabase-auth and retry
|
||||
if echo "$RESULT" | grep -q "Database error"; then
|
||||
docker restart supabase-auth && sleep 5
|
||||
curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}'
|
||||
fi
|
||||
|
||||
# Get auth token
|
||||
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
|
||||
-H "apikey: $ANON_KEY" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
|
||||
```
|
||||
|
||||
**Use this token for ALL API calls:**
|
||||
```bash
|
||||
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
|
||||
```
|
||||
|
||||
## Step 4: Run tests
|
||||
|
||||
### Service ports reference
|
||||
|
||||
| Service | Port | URL |
|
||||
|---------|------|-----|
|
||||
| Frontend | 3000 | http://localhost:3000 |
|
||||
| Backend REST | 8006 | http://localhost:8006 |
|
||||
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
|
||||
| Executor | 8002 | http://localhost:8002 |
|
||||
| Copilot Executor | 8008 | http://localhost:8008 |
|
||||
| WebSocket | 8001 | http://localhost:8001 |
|
||||
| Database Manager | 8005 | http://localhost:8005 |
|
||||
| Redis | 6379 | localhost:6379 |
|
||||
| RabbitMQ | 5672 | localhost:5672 |
|
||||
|
||||
### API testing
|
||||
|
||||
Use `curl` with the auth token for backend API tests. **For EVERY API call that changes state, record before/after values:**
|
||||
|
||||
```bash
|
||||
# Example: List agents
|
||||
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20
|
||||
|
||||
# Example: Create an agent
|
||||
curl -s -X POST http://localhost:8006/api/graphs \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{...}' | jq .
|
||||
|
||||
# Example: Run an agent
|
||||
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"data": {...}}'
|
||||
|
||||
# Example: Get execution results
|
||||
curl -s -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
|
||||
```
|
||||
|
||||
**State verification pattern (use for EVERY state-changing API call):**
|
||||
```bash
|
||||
# 1. Record BEFORE state
|
||||
BEFORE_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "BEFORE: $BEFORE_STATE"
|
||||
|
||||
# 2. Perform the action
|
||||
ACTION_RESULT=$(curl -s -X POST ... | jq .)
|
||||
echo "ACTION RESULT: $ACTION_RESULT"
|
||||
|
||||
# 3. Record AFTER state
|
||||
AFTER_STATE=$(curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/{resource} | jq '{relevant_fields}')
|
||||
echo "AFTER: $AFTER_STATE"
|
||||
|
||||
# 4. Log the comparison
|
||||
echo "=== STATE CHANGE VERIFICATION ==="
|
||||
echo "Before: $BEFORE_STATE"
|
||||
echo "After: $AFTER_STATE"
|
||||
echo "Expected change: {describe what should have changed}"
|
||||
```
|
||||
|
||||
### Browser testing with agent-browser
|
||||
|
||||
```bash
|
||||
# Close any existing session
|
||||
agent-browser close 2>/dev/null || true
|
||||
|
||||
# Use --session-name to persist cookies across navigations
|
||||
# This means login only needs to happen once per test session
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000
|
||||
|
||||
# Get interactive elements
|
||||
agent-browser --session-name pr-test snapshot | grep "textbox\|button"
|
||||
|
||||
# Login
|
||||
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
|
||||
agent-browser --session-name pr-test fill {password_ref} "testtest123"
|
||||
agent-browser --session-name pr-test click {login_button_ref}
|
||||
sleep 5
|
||||
|
||||
# Dismiss cookie banner if present
|
||||
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true
|
||||
|
||||
# Navigate — cookies are preserved so login persists
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
|
||||
# Take screenshot
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png
|
||||
|
||||
# Interact with elements
|
||||
agent-browser --session-name pr-test fill {ref} "text"
|
||||
agent-browser --session-name pr-test press "Enter"
|
||||
agent-browser --session-name pr-test click {ref}
|
||||
agent-browser --session-name pr-test click 'text=Button Text'
|
||||
|
||||
# Read page content
|
||||
agent-browser --session-name pr-test snapshot | grep "text:"
|
||||
```
|
||||
|
||||
**Key pages:**
|
||||
- `/copilot` — CoPilot chat (for testing copilot features)
|
||||
- `/build` — Agent builder (for testing block/node features)
|
||||
- `/build?flowID={id}` — Specific agent in builder
|
||||
- `/library` — Agent library (for testing listing/import features)
|
||||
- `/library/agents/{id}` — Agent detail with run history
|
||||
- `/marketplace` — Marketplace
|
||||
|
||||
### Checking logs
|
||||
|
||||
```bash
|
||||
# Backend REST server
|
||||
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30
|
||||
|
||||
# Executor (runs agent graphs)
|
||||
docker logs autogpt_platform-executor-1 2>&1 | tail -30
|
||||
|
||||
# Copilot executor (runs copilot chat sessions)
|
||||
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30
|
||||
|
||||
# Frontend
|
||||
docker logs autogpt_platform-frontend-1 2>&1 | tail -30
|
||||
|
||||
# Filter for errors
|
||||
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
|
||||
```
|
||||
|
||||
### Copilot chat testing
|
||||
|
||||
The copilot uses SSE streaming. To test via API:
|
||||
|
||||
```bash
|
||||
# Create a session
|
||||
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}' | jq -r '.id // .session_id // ""')
|
||||
|
||||
# Stream a message (SSE - will stream chunks)
|
||||
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"message": "Hello, what can you help me with?"}' \
|
||||
--max-time 60 2>/dev/null | head -50
|
||||
```
|
||||
|
||||
Or test via browser (preferred for UI verification):
|
||||
```bash
|
||||
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
|
||||
# ... fill chat input and press Enter, wait 20-30s for response
|
||||
```
|
||||
|
||||
## Step 5: Record results and take screenshots
|
||||
|
||||
**Take a screenshot at EVERY significant test step** — before and after interactions, on success, and on failure. This is NON-NEGOTIABLE.
|
||||
|
||||
**Required screenshot pattern for each test scenario:**
|
||||
```bash
|
||||
# BEFORE the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-before.png
|
||||
|
||||
# Perform the action...
|
||||
|
||||
# AFTER the action
|
||||
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{scenario}-after.png
|
||||
```
|
||||
|
||||
**Naming convention:**
|
||||
```bash
|
||||
# Examples:
|
||||
# $RESULTS_DIR/01-login-page-before.png
|
||||
# $RESULTS_DIR/02-login-page-after.png
|
||||
# $RESULTS_DIR/03-credits-page-before.png
|
||||
# $RESULTS_DIR/04-credits-purchase-after.png
|
||||
# $RESULTS_DIR/05-negative-insufficient-credits.png
|
||||
# $RESULTS_DIR/06-error-state.png
|
||||
```
|
||||
|
||||
**Minimum requirements:**
|
||||
- At least TWO screenshots per test scenario (before + after)
|
||||
- At least ONE screenshot for each negative test case showing the error state
|
||||
- If a test fails, screenshot the failure state AND any error logs visible in the UI
|
||||
|
||||
## Step 6: Show results to user with screenshots
|
||||
|
||||
**CRITICAL: After all tests complete, you MUST show every screenshot to the user using the Read tool, with an explanation of what each screenshot shows.** This is the most important part of the test report — the user needs to visually verify the results.
|
||||
|
||||
For each screenshot:
|
||||
1. Use the `Read` tool to display the PNG file (Claude can read images)
|
||||
2. Write a 1-2 sentence explanation below it describing:
|
||||
- What page/state is being shown
|
||||
- What the screenshot proves (which test scenario it validates)
|
||||
- Any notable details visible in the UI
|
||||
|
||||
Format the output like this:
|
||||
|
||||
```markdown
|
||||
### Screenshot 1: {descriptive title}
|
||||
[Read the PNG file here]
|
||||
|
||||
**What it shows:** {1-2 sentence explanation of what this screenshot proves}
|
||||
|
||||
---
|
||||
```
|
||||
|
||||
After showing all screenshots, output a **detailed** summary table:
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
| 1 | {name} | PASS/FAIL | Before: X, After: Y | 01-before.png, 02-after.png |
|
||||
| 2 | ... | ... | ... | ... |
|
||||
|
||||
**IMPORTANT:** As you show each screenshot and record test results, persist them in shell variables for Step 7:
|
||||
|
||||
```bash
|
||||
# Build these variables during Step 6 — they are required by Step 7's script
|
||||
# NOTE: declare -A requires Bash 4.0+. This is standard on modern systems (macOS ships zsh
|
||||
# but Homebrew bash is 5.x; Linux typically has bash 5.x). If running on Bash <4, use a
|
||||
# plain variable with a lookup function instead.
|
||||
declare -A SCREENSHOT_EXPLANATIONS=(
|
||||
["01-login-page.png"]="Shows the login page loaded successfully with SSO options visible."
|
||||
["02-builder-with-block.png"]="The builder canvas displays the newly added block connected to the trigger."
|
||||
# ... one entry per screenshot, using the same explanations you showed the user above
|
||||
)
|
||||
|
||||
TEST_RESULTS_TABLE="| 1 | Login flow | PASS | N/A | 01-login-before.png, 02-login-after.png |
|
||||
| 2 | Credits purchase | PASS | Before: 100, After: 95 | 03-credits-before.png, 04-credits-after.png |
|
||||
| 3 | Insufficient credits (negative) | PASS | Credits: 0, rejected | 05-insufficient-credits-error.png |"
|
||||
# ... one row per test scenario with actual results
|
||||
```
|
||||
|
||||
## Step 7: Post test report as PR comment with screenshots
|
||||
|
||||
Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees), then post a comment with inline images and per-screenshot explanations.
|
||||
|
||||
**This step is MANDATORY. Every test run MUST post a PR comment with screenshots. No exceptions.**
|
||||
|
||||
```bash
|
||||
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
|
||||
REPO="Significant-Gravitas/AutoGPT"
|
||||
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
|
||||
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"
|
||||
|
||||
# Step 1: Create blobs for each screenshot and build tree JSON
|
||||
# Retry each blob upload up to 3 times. If still failing, list them at end of report.
|
||||
shopt -s nullglob
|
||||
SCREENSHOT_FILES=("$RESULTS_DIR"/*.png)
|
||||
if [ ${#SCREENSHOT_FILES[@]} -eq 0 ]; then
|
||||
echo "ERROR: No screenshots found in $RESULTS_DIR. Test run is incomplete."
|
||||
exit 1
|
||||
fi
|
||||
TREE_JSON='['
|
||||
FIRST=true
|
||||
FAILED_UPLOADS=()
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
B64=$(base64 < "$img")
|
||||
BLOB_SHA=""
|
||||
for attempt in 1 2 3; do
|
||||
BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha' 2>/dev/null || true)
|
||||
[ -n "$BLOB_SHA" ] && break
|
||||
sleep 1
|
||||
done
|
||||
if [ -z "$BLOB_SHA" ]; then
|
||||
FAILED_UPLOADS+=("$img")
|
||||
continue
|
||||
fi
|
||||
if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
|
||||
TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
|
||||
done
|
||||
TREE_JSON+=']'
|
||||
|
||||
# Step 2: Create tree, commit, and branch ref
|
||||
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')
|
||||
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
|
||||
-f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
|
||||
-f tree="$TREE_SHA" \
|
||||
--jq '.sha')
|
||||
gh api "repos/${REPO}/git/refs" \
|
||||
-f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-f sha="$COMMIT_SHA" 2>/dev/null \
|
||||
|| gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
|
||||
-X PATCH -f sha="$COMMIT_SHA" -f force=true
|
||||
```
|
||||
|
||||
Then post the comment with **inline images AND explanations for each screenshot**:
|
||||
|
||||
```bash
|
||||
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
|
||||
|
||||
# Build image markdown using uploaded image URLs; skip FAILED_UPLOADS (listed separately)
|
||||
|
||||
IMAGE_MARKDOWN=""
|
||||
for img in "${SCREENSHOT_FILES[@]}"; do
|
||||
BASENAME=$(basename "$img")
|
||||
TITLE=$(echo "${BASENAME%.png}" | sed 's/^[0-9]*-//' | sed 's/-/ /g' | awk '{for(i=1;i<=NF;i++) $i=toupper(substr($i,1,1)) tolower(substr($i,2))}1')
|
||||
# Skip images that failed to upload — they will be listed at the end
|
||||
IS_FAILED=false
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
[ "$(basename "$failed")" = "$BASENAME" ] && IS_FAILED=true && break
|
||||
done
|
||||
if [ "$IS_FAILED" = true ]; then
|
||||
continue
|
||||
fi
|
||||
EXPLANATION="${SCREENSHOT_EXPLANATIONS[$BASENAME]}"
|
||||
if [ -z "$EXPLANATION" ]; then
|
||||
echo "ERROR: Missing screenshot explanation for $BASENAME. Add it to SCREENSHOT_EXPLANATIONS in Step 6."
|
||||
exit 1
|
||||
fi
|
||||
IMAGE_MARKDOWN="${IMAGE_MARKDOWN}
|
||||
### ${TITLE}
|
||||

|
||||
${EXPLANATION}
|
||||
"
|
||||
done
|
||||
|
||||
# Write comment body to file to avoid shell interpretation issues with special characters
|
||||
COMMENT_FILE=$(mktemp)
|
||||
# If any uploads failed, append a section listing them with instructions
|
||||
FAILED_SECTION=""
|
||||
if [ ${#FAILED_UPLOADS[@]} -gt 0 ]; then
|
||||
FAILED_SECTION="
|
||||
## ⚠️ Failed Screenshot Uploads
|
||||
The following screenshots could not be uploaded via the GitHub API after 3 retries.
|
||||
**To add them:** drag-and-drop or paste these files into a PR comment manually:
|
||||
"
|
||||
for failed in "${FAILED_UPLOADS[@]}"; do
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
- \`$(basename "$failed")\` (local path: \`$failed\`)"
|
||||
done
|
||||
FAILED_SECTION="${FAILED_SECTION}
|
||||
|
||||
**Run status:** INCOMPLETE until the files above are manually attached and visible inline in the PR."
|
||||
fi
|
||||
|
||||
cat > "$COMMENT_FILE" <<INNEREOF
|
||||
## E2E Test Report
|
||||
|
||||
| # | Scenario | Result | API Evidence | Screenshot Evidence |
|
||||
|---|----------|--------|-------------|-------------------|
|
||||
${TEST_RESULTS_TABLE}
|
||||
|
||||
${IMAGE_MARKDOWN}
|
||||
${FAILED_SECTION}
|
||||
INNEREOF
|
||||
|
||||
gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -F body=@"$COMMENT_FILE"
|
||||
rm -f "$COMMENT_FILE"
|
||||
```
|
||||
|
||||
**The PR comment MUST include:**
|
||||
1. A summary table of all scenarios with PASS/FAIL and before/after API evidence
|
||||
2. Every successfully uploaded screenshot rendered inline; any failed uploads listed with manual attachment instructions
|
||||
3. A 1-2 sentence explanation below each screenshot describing what it proves
|
||||
|
||||
This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.
|
||||
|
||||
## Fix mode (--fix flag)
|
||||
|
||||
When `--fix` is present, the standard is HIGHER. Do not just note issues — FIX them immediately.
|
||||
|
||||
### Fix protocol for EVERY issue found (including UX issues):
|
||||
|
||||
1. **Identify** the root cause in the code — read the relevant source files
|
||||
2. **Write a failing test first** (TDD): For backend bugs, write a test marked with `pytest.mark.xfail(reason="...")`. For frontend/Playwright bugs, write a test with `.fixme` annotation. Run it to confirm it fails as expected.
|
||||
3. **Screenshot** the broken state: `agent-browser screenshot $RESULTS_DIR/{NN}-broken-{description}.png`
|
||||
4. **Fix** the code in the worktree
|
||||
5. **Rebuild** ONLY the affected service (not the whole stack):
|
||||
```bash
|
||||
cd $PLATFORM_DIR && docker compose up --build -d {service_name}
|
||||
# e.g., docker compose up --build -d rest_server
|
||||
# e.g., docker compose up --build -d frontend
|
||||
```
|
||||
6. **Wait** for the service to be ready (poll health endpoint)
|
||||
7. **Re-test** the same scenario
|
||||
8. **Screenshot** the fixed state: `agent-browser screenshot $RESULTS_DIR/{NN}-fixed-{description}.png`
|
||||
9. **Remove the xfail/fixme marker** from the test written in step 2, and verify it passes
|
||||
10. **Verify** the fix did not break other scenarios (run a quick smoke test)
|
||||
11. **Commit and push** immediately:
|
||||
```bash
|
||||
cd $WORKTREE_PATH
|
||||
git add -A
|
||||
git commit -m "fix: {description of fix}"
|
||||
git push
|
||||
```
|
||||
12. **Continue** to the next test scenario
|
||||
|
||||
### Fix loop (like pr-address)
|
||||
|
||||
```text
|
||||
test scenario → find issue (bug OR UX problem) → screenshot broken state
|
||||
→ fix code → rebuild affected service only → re-test → screenshot fixed state
|
||||
→ verify no regressions → commit + push
|
||||
→ repeat for next scenario
|
||||
→ after ALL scenarios pass, run full re-test to verify everything together
|
||||
```
|
||||
|
||||
**Key differences from non-fix mode:**
|
||||
- UX issues count as bugs — fix them (bad alignment, confusing labels, missing loading states)
|
||||
- Every fix MUST have a before/after screenshot pair proving it works
|
||||
- Commit after EACH fix, not in a batch at the end
|
||||
- The final re-test must produce a clean set of all-passing screenshots
|
||||
|
||||
## Known issues and workarounds
|
||||
|
||||
### Problem: "Database error finding user" on signup
|
||||
**Cause:** Supabase auth service schema cache is stale after migration.
|
||||
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.
|
||||
|
||||
### Problem: Copilot returns auth errors in subscription mode
|
||||
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or expired.
|
||||
**Fix:** Re-extract the OAuth token from macOS keychain (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.
|
||||
|
||||
### Problem: agent-browser can't find chromium
|
||||
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
|
||||
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.
|
||||
|
||||
### Problem: agent-browser selector matches multiple elements
|
||||
**Cause:** `text=X` matches all elements containing that text.
|
||||
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.
|
||||
|
||||
### Problem: Frontend shows cookie banner blocking interaction
|
||||
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.
|
||||
|
||||
### Problem: Container loses npm packages after rebuild
|
||||
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
|
||||
**Fix:** Add packages to the Dockerfile instead of installing at runtime.
|
||||
|
||||
### Problem: Services not starting after `docker compose up`
|
||||
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.
|
||||
|
||||
### Problem: Docker uses cached layers with old code (PR changes not visible)
|
||||
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
|
||||
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.
|
||||
|
||||
### Problem: `agent-browser open` loses login session
|
||||
**Cause:** Without session persistence, `agent-browser open` starts fresh.
|
||||
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.
|
||||
|
||||
### Problem: Supabase auth returns "Database error querying schema"
|
||||
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
|
||||
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.
|
||||
195
.claude/skills/setup-repo/SKILL.md
Normal file
195
.claude/skills/setup-repo/SKILL.md
Normal file
@@ -0,0 +1,195 @@
|
||||
---
|
||||
name: setup-repo
|
||||
description: Initialize a worktree-based repo layout for parallel development. Creates a main worktree, a reviews worktree for PR reviews, and N numbered work branches. Handles .env creation, dependency installation, and branchlet config. TRIGGER when user asks to set up the repo from scratch, initialize worktrees, bootstrap their dev environment, "setup repo", "setup worktrees", "initialize dev environment", "set up branches", or when a freshly cloned repo has no sibling worktrees.
|
||||
user-invocable: true
|
||||
args: "No arguments — interactive setup via prompts."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Repository Setup
|
||||
|
||||
This skill sets up a worktree-based development layout from a freshly cloned repo. It creates:
|
||||
- A **main** worktree (the primary checkout)
|
||||
- A **reviews** worktree (for PR reviews)
|
||||
- **N work branches** (branch1..branchN) for parallel development
|
||||
|
||||
## Step 1: Identify the repo
|
||||
|
||||
Determine the repo root and parent directory:
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
REPO_NAME=$(basename "$ROOT")
|
||||
PARENT=$(dirname "$ROOT")
|
||||
```
|
||||
|
||||
Detect if the repo is already inside a worktree layout by counting sibling worktrees (not just checking the directory name, which could be anything):
|
||||
|
||||
```bash
|
||||
# Count worktrees that are siblings (live under $PARENT but aren't $ROOT itself)
|
||||
SIBLING_COUNT=$(git worktree list --porcelain 2>/dev/null | grep "^worktree " | grep -c "$PARENT/" || true)
|
||||
if [ "$SIBLING_COUNT" -gt 1 ]; then
|
||||
echo "INFO: Existing worktree layout detected at $PARENT ($SIBLING_COUNT worktrees)"
|
||||
# Use $ROOT as-is; skip renaming/restructuring
|
||||
else
|
||||
echo "INFO: Fresh clone detected, proceeding with setup"
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 2: Ask the user questions
|
||||
|
||||
Use AskUserQuestion to gather setup preferences:
|
||||
|
||||
1. **How many parallel work branches do you need?** (Options: 4, 8, 16, or custom)
|
||||
- These become `branch1` through `branchN`
|
||||
2. **Which branch should be the base?** (Options: origin/master, origin/dev, or custom)
|
||||
- All work branches and reviews will start from this
|
||||
|
||||
## Step 3: Fetch and set up branches
|
||||
|
||||
```bash
|
||||
cd "$ROOT"
|
||||
git fetch origin
|
||||
|
||||
# Create the reviews branch from base (skip if already exists)
|
||||
if git show-ref --verify --quiet refs/heads/reviews; then
|
||||
echo "INFO: Branch 'reviews' already exists, skipping"
|
||||
else
|
||||
git branch reviews <base-branch>
|
||||
fi
|
||||
|
||||
# Create numbered work branches from base (skip if already exists)
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if git show-ref --verify --quiet "refs/heads/branch$i"; then
|
||||
echo "INFO: Branch 'branch$i' already exists, skipping"
|
||||
else
|
||||
git branch "branch$i" <base-branch>
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 4: Create worktrees
|
||||
|
||||
Create worktrees as siblings to the main checkout:
|
||||
|
||||
```bash
|
||||
if [ -d "$PARENT/reviews" ]; then
|
||||
echo "INFO: Worktree '$PARENT/reviews' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/reviews" reviews
|
||||
fi
|
||||
|
||||
for i in $(seq 1 "$COUNT"); do
|
||||
if [ -d "$PARENT/branch$i" ]; then
|
||||
echo "INFO: Worktree '$PARENT/branch$i' already exists, skipping"
|
||||
else
|
||||
git worktree add "$PARENT/branch$i" "branch$i"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Step 5: Set up environment files
|
||||
|
||||
**Do NOT assume .env files exist.** For each worktree (including main if needed):
|
||||
|
||||
1. Check if `.env` exists in the source worktree for each path
|
||||
2. If `.env` exists, copy it
|
||||
3. If only `.env.default` or `.env.example` exists, copy that as `.env`
|
||||
4. If neither exists, warn the user and list which env files are missing
|
||||
|
||||
Env file locations to check (same as the `/worktree` skill — keep these in sync):
|
||||
- `autogpt_platform/.env`
|
||||
- `autogpt_platform/backend/.env`
|
||||
- `autogpt_platform/frontend/.env`
|
||||
|
||||
> **Note:** This env copying logic intentionally mirrors the `/worktree` skill's approach. If you update the path list or fallback logic here, update `/worktree` as well.
|
||||
|
||||
```bash
|
||||
SOURCE="$ROOT"
|
||||
WORKTREES="reviews"
|
||||
for i in $(seq 1 "$COUNT"); do WORKTREES="$WORKTREES branch$i"; done
|
||||
|
||||
FOUND_ANY_ENV=0
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
for envpath in autogpt_platform autogpt_platform/backend autogpt_platform/frontend; do
|
||||
if [ -f "$SOURCE/$envpath/.env" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$SOURCE/$envpath/.env.default" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.default — you may need to edit it"
|
||||
elif [ -f "$SOURCE/$envpath/.env.example" ]; then
|
||||
FOUND_ANY_ENV=1
|
||||
cp "$SOURCE/$envpath/.env.example" "$TARGET/$envpath/.env"
|
||||
echo "NOTE: $wt/$envpath/.env was created from .env.example — you may need to edit it"
|
||||
else
|
||||
echo "WARNING: No .env, .env.default, or .env.example found at $SOURCE/$envpath/"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [ "$FOUND_ANY_ENV" -eq 0 ]; then
|
||||
echo "WARNING: No environment files or templates were found in the source worktree."
|
||||
# Use AskUserQuestion to confirm: "Continue setup without env files?"
|
||||
# If the user declines, stop here and let them set up .env files first.
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 6: Copy branchlet config
|
||||
|
||||
Copy `.branchlet.json` from main to each worktree so branchlet can manage sub-worktrees:
|
||||
|
||||
```bash
|
||||
if [ -f "$ROOT/.branchlet.json" ]; then
|
||||
for wt in $WORKTREES; do
|
||||
cp "$ROOT/.branchlet.json" "$PARENT/$wt/.branchlet.json"
|
||||
done
|
||||
fi
|
||||
```
|
||||
|
||||
## Step 7: Install dependencies
|
||||
|
||||
Install deps in all worktrees. Run these sequentially per worktree:
|
||||
|
||||
```bash
|
||||
for wt in $WORKTREES; do
|
||||
TARGET="$PARENT/$wt"
|
||||
echo "=== Installing deps for $wt ==="
|
||||
(cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install) &&
|
||||
(cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate) &&
|
||||
(cd "$TARGET/autogpt_platform/frontend" && pnpm install) &&
|
||||
echo "=== Done: $wt ===" ||
|
||||
echo "=== FAILED: $wt ==="
|
||||
done
|
||||
```
|
||||
|
||||
This is slow. Run in background if possible and notify when complete.
|
||||
|
||||
## Step 8: Verify and report
|
||||
|
||||
After setup, verify and report to the user:
|
||||
|
||||
```bash
|
||||
git worktree list
|
||||
```
|
||||
|
||||
Summarize:
|
||||
- Number of worktrees created
|
||||
- Which env files were copied vs created from defaults vs missing
|
||||
- Any warnings or errors encountered
|
||||
|
||||
## Final directory layout
|
||||
|
||||
```
|
||||
parent/
|
||||
main/ # Primary checkout (already exists)
|
||||
reviews/ # PR review worktree
|
||||
branch1/ # Work branch 1
|
||||
branch2/ # Work branch 2
|
||||
...
|
||||
branchN/ # Work branch N
|
||||
```
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
85
.claude/skills/worktree/SKILL.md
Normal file
85
.claude/skills/worktree/SKILL.md
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: worktree
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "3.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Create the worktree
|
||||
|
||||
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
PARENT=$(dirname "$ROOT")
|
||||
|
||||
# From an existing branch
|
||||
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||
|
||||
# From a new branch off dev
|
||||
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||
```
|
||||
|
||||
## Copy environment files
|
||||
|
||||
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||
|
||||
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||
```
|
||||
|
||||
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||
|
||||
## Running the app (optional)
|
||||
|
||||
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot testing
|
||||
|
||||
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
|
||||
## Cleanup
|
||||
|
||||
```bash
|
||||
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
```
|
||||
|
||||
## Alternative: Branchlet (optional)
|
||||
|
||||
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||
|
||||
```bash
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
```
|
||||
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
8
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,8 +1,12 @@
|
||||
<!-- Clearly explain the need for these changes: -->
|
||||
### Why / What / How
|
||||
|
||||
<!-- Why: Why does this PR exist? What problem does it solve, or what's broken/missing without it? -->
|
||||
<!-- What: What does this PR change? Summarize the changes at a high level. -->
|
||||
<!-- How: How does it work? Describe the approach, key implementation details, or architecture decisions. -->
|
||||
|
||||
### Changes 🏗️
|
||||
|
||||
<!-- Concisely describe all of the changes made in this pull request: -->
|
||||
<!-- List the key changes. Keep it higher level than the diff but specific enough to highlight what's new/modified. -->
|
||||
|
||||
### Checklist 📋
|
||||
|
||||
|
||||
116
.github/workflows/platform-backend-ci.yml
vendored
116
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,12 +5,14 @@ on:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
merge_group:
|
||||
@@ -25,10 +27,91 @@ defaults:
|
||||
working-directory: autogpt_platform/backend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run Linters
|
||||
run: poetry run lint --skip-pyright
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
type-check:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Run Pyright
|
||||
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -96,9 +179,9 @@ jobs:
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
@@ -156,22 +239,22 @@ jobs:
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
@@ -186,18 +269,13 @@ jobs:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run pytest
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
@@ -209,6 +287,12 @@ jobs:
|
||||
REDIS_PORT: "6379"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
@@ -222,9 +306,3 @@ jobs:
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
|
||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,175 +120,6 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
312
.github/workflows/platform-fullstack-ci.yml
vendored
312
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,14 +1,18 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
name: AutoGPT Platform - Full-stack CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
@@ -24,42 +28,28 @@ defaults:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies to populate cache
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: big-boi
|
||||
check-api-types:
|
||||
name: check API types
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -67,70 +57,256 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
# ------------------------ Backend setup ------------------------
|
||||
|
||||
- name: Set up Backend - Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Backend - Install Poetry
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
|
||||
- name: Set up Backend - Set up dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Set up Backend - Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Set up Backend - Generate Prisma client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
|
||||
# ------------------------ Frontend setup ------------------------
|
||||
|
||||
- name: Set up Frontend - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Frontend - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Set up Frontend - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
- name: Set up Frontend - Format OpenAPI schema
|
||||
id: format-schema
|
||||
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "The openapi.json file has been modified after exporting the API schema."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo "\nIn the backend directory:"
|
||||
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||
echo "\nIn the frontend directory:"
|
||||
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||
echo "3. Run 'pnpm generate:api'"
|
||||
echo "4. Run 'pnpm types'"
|
||||
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||
echo "6. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
- name: Set up Frontend - Generate API client
|
||||
id: generate-api-client
|
||||
run: pnpm orval --config ./orval.config.ts
|
||||
# Continue with type generation & check even if there are schema changes
|
||||
if: success() || (steps.format-schema.outcome == 'success')
|
||||
|
||||
- name: Check for TypeScript errors
|
||||
run: pnpm types
|
||||
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: autogpt_platform/frontend/playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: autogpt_platform/frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
This guide provides context for coding agents when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
|
||||
|
||||
@@ -83,13 +83,13 @@ The AutoGPT frontend is where users interact with our powerful AI automation pla
|
||||
|
||||
**Agent Builder:** For those who want to customize, our intuitive, low-code interface allows you to design and configure your own AI agents.
|
||||
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
**Workflow Management:** Build, modify, and optimize your automation workflows with ease. You build your agent by connecting blocks, where each block performs a single action.
|
||||
|
||||
**Deployment Controls:** Manage the lifecycle of your agents, from testing to production.
|
||||
|
||||
**Ready-to-Use Agents:** Don't want to build? Simply select from our library of pre-configured agents and put them to work immediately.
|
||||
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
**Agent Interaction:** Whether you've built your own or are using pre-configured agents, easily run and interact with them through our user-friendly interface.
|
||||
|
||||
**Monitoring and Analytics:** Keep track of your agents' performance and gain insights to continually improve your automation processes.
|
||||
|
||||
|
||||
120
autogpt_platform/AGENTS.md
Normal file
120
autogpt_platform/AGENTS.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# AutoGPT Platform
|
||||
|
||||
This file provides guidance to coding agents when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/AGENTS.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/AGENTS.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- **Split PRs by concern** — each PR should have a single clear purpose. For example, "usage tracking" and "credit charging" should be separate PRs even if related. Combining multiple concerns makes it harder for reviewers to understand what belongs to what.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- **Structure the PR description with Why / What / How** — Why: the motivation (what problem it solves, what's broken/missing without it); What: high-level summary of changes; How: approach, key implementation details, or architecture decisions. Reviewers need all three to judge whether the approach fits the problem.
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
|
||||
```bash
|
||||
PR_BODY=$(mktemp)
|
||||
cat > "$PR_BODY" << 'PREOF'
|
||||
## Summary
|
||||
- use `backticks` freely here
|
||||
PREOF
|
||||
gh pr create --title "..." --body-file "$PR_BODY" --base dev
|
||||
rm "$PR_BODY"
|
||||
```
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, follow a test-first approach:
|
||||
|
||||
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
|
||||
2. **Implement the fix/feature** — write the minimal code to make the test pass.
|
||||
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
|
||||
|
||||
This ensures every change is covered by a test and that the test actually validates the intended behavior.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@@ -1,95 +1 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@AGENTS.md
|
||||
|
||||
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
@@ -0,0 +1,40 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.auth_activities
|
||||
-- Looker source alias: ds49 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Tracks authentication events (login, logout, SSO, password
|
||||
-- reset, etc.) from Supabase's internal audit log.
|
||||
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.audit_log_entries — Supabase internal auth event log
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||
-- actor_id TEXT User ID who triggered the event
|
||||
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days from current date
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily login counts
|
||||
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||
-- FROM analytics.auth_activities
|
||||
-- WHERE action = 'login'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- SSO vs password login breakdown
|
||||
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||
-- WHERE action = 'login' GROUP BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
created_at,
|
||||
payload->>'actor_id' AS actor_id,
|
||||
payload->>'actor_via_sso' AS actor_via_sso,
|
||||
payload->>'action' AS action
|
||||
FROM auth.audit_log_entries
|
||||
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
@@ -0,0 +1,105 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.graph_execution
|
||||
-- Looker source alias: ds16 | Charts: 21
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per agent graph execution (last 90 days).
|
||||
-- Unpacks the JSONB stats column into individual numeric columns
|
||||
-- and normalises the executionStatus — runs that failed due to
|
||||
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||
-- to allow safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Execution UUID
|
||||
-- agentGraphId TEXT Agent graph UUID
|
||||
-- agentGraphVersion INT Graph version number
|
||||
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||
-- updatedAt TIMESTAMPTZ Last status update time
|
||||
-- userId TEXT Owner user UUID
|
||||
-- agentGraphName TEXT Human-readable agent name
|
||||
-- cputime DECIMAL Total CPU seconds consumed
|
||||
-- walltime DECIMAL Total wall-clock seconds
|
||||
-- node_count DECIMAL Number of nodes in the graph
|
||||
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||
-- execution_cost DECIMAL Credit cost of this execution
|
||||
-- correctness_score FLOAT AI correctness score (if available)
|
||||
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily execution counts by status
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||
-- FROM analytics.graph_execution
|
||||
-- GROUP BY 1, 2 ORDER BY 1;
|
||||
--
|
||||
-- -- Average cost per execution by agent
|
||||
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'COMPLETED'
|
||||
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||
--
|
||||
-- -- Top error messages
|
||||
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ge."id" AS id,
|
||||
ge."agentGraphId" AS agentGraphId,
|
||||
ge."agentGraphVersion" AS agentGraphVersion,
|
||||
CASE
|
||||
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||
AND (
|
||||
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||
)
|
||||
THEN 'NO_CREDITS'
|
||||
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||
END AS executionStatus,
|
||||
ge."createdAt" AS createdAt,
|
||||
ge."updatedAt" AS updatedAt,
|
||||
ge."userId" AS userId,
|
||||
g."name" AS agentGraphName,
|
||||
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage
|
||||
FROM platform."AgentGraphExecution" ge
|
||||
LEFT JOIN platform."AgentGraph" g
|
||||
ON ge."agentGraphId" = g."id"
|
||||
AND ge."agentGraphVersion" = g."version"
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||
"userId", "agentGraphId",
|
||||
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||
FROM platform."LibraryAgent"
|
||||
WHERE "isDeleted" = FALSE
|
||||
AND "isArchived" = FALSE
|
||||
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
@@ -0,0 +1,101 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.node_block_execution
|
||||
-- Looker source alias: ds14 | Charts: 11
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per node (block) execution (last 90 days).
|
||||
-- Unpacks stats JSONB and joins to identify which block type
|
||||
-- was run. For failed nodes, joins the error output and
|
||||
-- scrubs it for safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentNodeExecution — Node execution records
|
||||
-- platform.AgentNode — Node → block mapping
|
||||
-- platform.AgentBlock — Block name/ID
|
||||
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Node execution UUID
|
||||
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||
-- agentNodeId TEXT Node UUID within the graph
|
||||
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||
-- addedTime TIMESTAMPTZ When the node was queued
|
||||
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||
-- startedTime TIMESTAMPTZ When execution started
|
||||
-- endedTime TIMESTAMPTZ When execution finished
|
||||
-- inputSize BIGINT Input payload size in bytes
|
||||
-- outputSize BIGINT Output payload size in bytes
|
||||
-- walltime NUMERIC Wall-clock seconds for this node
|
||||
-- cputime NUMERIC CPU seconds for this node
|
||||
-- llmRetryCount INT Number of LLM retries
|
||||
-- llmCallCount INT Number of LLM API calls made
|
||||
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||
-- outputTokenCount BIGINT LLM output tokens produced
|
||||
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||
-- blockId TEXT Block UUID
|
||||
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Most-used blocks by execution count
|
||||
-- SELECT "blockName", COUNT(*) AS executions,
|
||||
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||
-- FROM analytics.node_block_execution
|
||||
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||
--
|
||||
-- -- Average LLM token usage per block
|
||||
-- SELECT "blockName",
|
||||
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "llmCallCount" > 0
|
||||
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||
--
|
||||
-- -- Top failure reasons
|
||||
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ne."id" AS id,
|
||||
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||
ne."agentNodeId" AS agentNodeId,
|
||||
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||
ne."addedTime" AS addedTime,
|
||||
ne."queuedTime" AS queuedTime,
|
||||
ne."startedTime" AS startedTime,
|
||||
ne."endedTime" AS endedTime,
|
||||
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||
b."name" AS blockName,
|
||||
b."id" AS blockId,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM eio."data"::text),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage,
|
||||
eio."data" AS errorMessage
|
||||
FROM platform."AgentNodeExecution" ne
|
||||
LEFT JOIN platform."AgentNode" nd
|
||||
ON ne."agentNodeId" = nd."id"
|
||||
LEFT JOIN platform."AgentBlock" b
|
||||
ON nd."agentBlockId" = b."id"
|
||||
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||
ON eio."referencedByOutputExecId" = ne."id"
|
||||
AND eio."name" = 'error'
|
||||
AND ne."executionStatus" = 'FAILED'
|
||||
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
@@ -0,0 +1,97 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_agent
|
||||
-- Looker source alias: ds35 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention broken down per individual agent.
|
||||
-- Cohort = week of a user's first use of THAT specific agent.
|
||||
-- Tells you which agents keep users coming back vs. one-shot
|
||||
-- use. Only includes cohorts from the last 180 days.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||
-- platform.AgentGraph — Agent names
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- agent_id TEXT Agent graph UUID
|
||||
-- agent_label TEXT 'AgentName [first8chars]'
|
||||
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||
-- cohort_week_start DATE Week users first ran this agent
|
||||
-- cohort_label TEXT ISO week label
|
||||
-- cohort_label_n TEXT ISO week label with cohort size
|
||||
-- user_lifetime_week INT Weeks since first use of this agent
|
||||
-- cohort_users BIGINT Users in this cohort for this agent
|
||||
-- active_users BIGINT Users who ran the agent again in week k
|
||||
-- retention_rate FLOAT active_users / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Best-retained agents at week 2
|
||||
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||
-- FROM analytics.retention_agent
|
||||
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||
--
|
||||
-- -- Agents with most unique users
|
||||
-- SELECT DISTINCT agent_label, agent_total_users
|
||||
-- FROM analytics.retention_agent
|
||||
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||
e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e
|
||||
),
|
||||
first_use AS (
|
||||
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1,2
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||
),
|
||||
active_counts AS (
|
||||
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||
),
|
||||
cohort_sizes AS (
|
||||
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||
),
|
||||
cohort_caps AS (
|
||||
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
),
|
||||
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||
SELECT
|
||||
g.agent_id,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(ac.active_users,0) AS active_users,
|
||||
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||
FROM grid g
|
||||
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_daily
|
||||
-- Looker source alias: ds111 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on agent executions.
|
||||
-- Cohort anchor = day of user's FIRST ever execution.
|
||||
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||
-- Great for early engagement analysis (did users run another
|
||||
-- agent the next day?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_daily.
|
||||
-- cohort_day_start = day of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-3 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||
-- FROM analytics.retention_execution_daily
|
||||
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fe.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_weekly
|
||||
-- Looker source alias: ds92 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on agent executions.
|
||||
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||
-- (not first login). Only includes cohorts from the last 180 days.
|
||||
-- Useful when you care about product engagement, not just visits.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_weekly.
|
||||
-- cohort_week_start = week of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-2 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded
|
||||
-- FROM analytics.retention_execution_weekly
|
||||
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fe.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
@@ -0,0 +1,94 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_daily
|
||||
-- Looker source alias: ds112 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on login sessions.
|
||||
-- Same logic as retention_login_weekly but at day granularity,
|
||||
-- showing up to day 30 for cohorts from the last 90 days.
|
||||
-- Useful for analysing early activation (days 1-7) in detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||
-- cohort_day_start DATE First day the cohort logged in
|
||||
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||
-- cohort_users BIGINT Total users in cohort
|
||||
-- active_users_bounded BIGINT Users active on exactly day k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-1 retention rate (came back next day)
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||
--
|
||||
-- -- Average retention curve across all cohorts
|
||||
-- SELECT user_lifetime_day,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fl.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,96 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_onboarded_weekly
|
||||
-- Looker source alias: ds101 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention from login sessions, restricted to
|
||||
-- users who "onboarded" — defined as running at least one
|
||||
-- agent within 365 days of their first login.
|
||||
-- Filters out users who signed up but never activated,
|
||||
-- giving a cleaner view of engaged-user retention.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||
-- Only difference: cohort is filtered to onboarded users only.
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Compare week-4 retention: all users vs onboarded only
|
||||
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||
-- UNION ALL
|
||||
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login_all AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
onboarders AS (
|
||||
SELECT fl.user_id FROM first_login_all fl
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||
WHERE e."userId"::text = fl.user_id
|
||||
AND e."createdAt" >= fl.first_login_time
|
||||
AND e."createdAt" < fl.first_login_time
|
||||
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||
)
|
||||
),
|
||||
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
@@ -0,0 +1,103 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_weekly
|
||||
-- Looker source alias: ds83 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on login sessions.
|
||||
-- Users are grouped by the ISO week of their first ever login.
|
||||
-- For each cohort × lifetime-week combination, outputs both:
|
||||
-- - bounded rate: % active in exactly that week
|
||||
-- - unbounded rate: % who were ever active on or after that week
|
||||
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- HOW TO READ THE OUTPUT
|
||||
-- cohort_week_start The Monday of the week users first logged in
|
||||
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- cohort_week_start DATE First day of the cohort's signup week
|
||||
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||
-- active_users_bounded BIGINT Users active in exactly week k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-1 retention rate per cohort
|
||||
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- WHERE user_lifetime_week = 1
|
||||
-- ORDER BY cohort_week_start;
|
||||
--
|
||||
-- -- Overall average retention curve (all cohorts combined)
|
||||
-- SELECT user_lifetime_week,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
@@ -0,0 +1,71 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_block_spending
|
||||
-- Looker source alias: ds6 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per credit transaction (last 90 days).
|
||||
-- Shows how users spend credits broken down by block type,
|
||||
-- LLM provider and model. Joins node execution stats for
|
||||
-- token-level detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.CreditTransaction — Credit debit/credit records
|
||||
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- transactionKey TEXT Unique transaction identifier
|
||||
-- userId TEXT User who was charged
|
||||
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||
-- blockId TEXT Block UUID that triggered the spend
|
||||
-- blockName TEXT Human-readable block name
|
||||
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||
-- node_exec_id TEXT Linked node execution UUID
|
||||
-- llm_call_count INT LLM API calls made in that execution
|
||||
-- llm_retry_count INT LLM retries in that execution
|
||||
-- llm_input_token_count INT Input tokens consumed
|
||||
-- llm_output_token_count INT Output tokens produced
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend per user (last 90 days)
|
||||
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "transactionType" = 'USAGE'
|
||||
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||
--
|
||||
-- -- Spend by LLM provider + model
|
||||
-- SELECT "llm_provider", "llm_model",
|
||||
-- SUM("negativeAmount") AS total_cost,
|
||||
-- SUM("llm_input_token_count") AS input_tokens,
|
||||
-- SUM("llm_output_token_count") AS output_tokens
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "llm_provider" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
c."transactionKey" AS transactionKey,
|
||||
c."userId" AS userId,
|
||||
c."amount" AS amount,
|
||||
c."amount" * -1 AS negativeAmount,
|
||||
c."type" AS transactionType,
|
||||
c."createdAt" AS transactionTime,
|
||||
c.metadata->>'block_id' AS blockId,
|
||||
c.metadata->>'block' AS blockName,
|
||||
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||
c.metadata->'input'->>'model' AS llm_model,
|
||||
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||
FROM platform."CreditTransaction" c
|
||||
LEFT JOIN platform."AgentNodeExecution" ne
|
||||
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
@@ -0,0 +1,45 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding
|
||||
-- Looker source alias: ds68 | Charts: 3
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user onboarding record. Contains the user's
|
||||
-- stated usage reason, selected integrations, completed
|
||||
-- onboarding steps and optional first agent selection.
|
||||
-- Full history (no date filter) since onboarding happens
|
||||
-- once per user.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding state per user
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Onboarding record UUID
|
||||
-- createdAt TIMESTAMPTZ When onboarding started
|
||||
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||
-- integrations TEXT[] Array of integration names the user selected
|
||||
-- userId TEXT User UUID
|
||||
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Usage reason breakdown
|
||||
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||
--
|
||||
-- -- Completion rate per step
|
||||
-- SELECT step, COUNT(*) AS users_completed
|
||||
-- FROM analytics.user_onboarding
|
||||
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"usageReason",
|
||||
integrations,
|
||||
"userId",
|
||||
"completedSteps",
|
||||
"selectedStoreListingVersionId"
|
||||
FROM platform."UserOnboarding"
|
||||
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
@@ -0,0 +1,100 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_funnel
|
||||
-- Looker source alias: ds74 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated onboarding funnel showing how many users
|
||||
-- completed each step and the drop-off percentage from the
|
||||
-- previous step. One row per onboarding step (all 22 steps
|
||||
-- always present, even with 0 completions — prevents sparse
|
||||
-- gaps from making LAG compare the wrong predecessors).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||
-- users_completed BIGINT Distinct users who completed this step
|
||||
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||
--
|
||||
-- STEP ORDER
|
||||
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full funnel
|
||||
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||
--
|
||||
-- -- Biggest drop-off point
|
||||
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||
-- =============================================================
|
||||
|
||||
WITH all_steps AS (
|
||||
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||
-- are always present, keeping LAG comparisons correct.
|
||||
SELECT step_name, step_order
|
||||
FROM (VALUES
|
||||
('WELCOME', 1),
|
||||
('USAGE_REASON', 2),
|
||||
('INTEGRATIONS', 3),
|
||||
('AGENT_CHOICE', 4),
|
||||
('AGENT_NEW_RUN', 5),
|
||||
('AGENT_INPUT', 6),
|
||||
('CONGRATS', 7),
|
||||
('GET_RESULTS', 8),
|
||||
('MARKETPLACE_VISIT', 9),
|
||||
('MARKETPLACE_ADD_AGENT', 10),
|
||||
('MARKETPLACE_RUN_AGENT', 11),
|
||||
('BUILDER_OPEN', 12),
|
||||
('BUILDER_SAVE_AGENT', 13),
|
||||
('BUILDER_RUN_AGENT', 14),
|
||||
('VISIT_COPILOT', 15),
|
||||
('RE_RUN_AGENT', 16),
|
||||
('SCHEDULE_AGENT', 17),
|
||||
('RUN_AGENTS', 18),
|
||||
('RUN_3_DAYS', 19),
|
||||
('TRIGGER_WEBHOOK', 20),
|
||||
('RUN_14_DAYS', 21),
|
||||
('RUN_AGENTS_100', 22)
|
||||
) AS t(step_name, step_order)
|
||||
),
|
||||
raw AS (
|
||||
SELECT
|
||||
u."userId",
|
||||
step_txt::text AS step
|
||||
FROM platform."UserOnboarding" u
|
||||
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
),
|
||||
step_counts AS (
|
||||
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||
FROM raw GROUP BY step
|
||||
),
|
||||
funnel AS (
|
||||
SELECT
|
||||
a.step_name AS step,
|
||||
a.step_order,
|
||||
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||
ROUND(
|
||||
100.0 * COALESCE(sc.users_completed, 0)
|
||||
/ NULLIF(
|
||||
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||
0
|
||||
),
|
||||
2
|
||||
) AS pct_from_prev
|
||||
FROM all_steps a
|
||||
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||
)
|
||||
SELECT * FROM funnel ORDER BY step_order
|
||||
@@ -0,0 +1,41 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_integration
|
||||
-- Looker source alias: ds75 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated count of users who selected each integration
|
||||
-- during onboarding. One row per integration type, sorted
|
||||
-- by popularity.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — integrations array column
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full integration popularity ranking
|
||||
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||
--
|
||||
-- -- Top 5 integrations
|
||||
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||
-- =============================================================
|
||||
|
||||
WITH exploded AS (
|
||||
SELECT
|
||||
u."userId" AS user_id,
|
||||
UNNEST(u."integrations") AS integration
|
||||
FROM platform."UserOnboarding" u
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
)
|
||||
SELECT
|
||||
integration,
|
||||
COUNT(DISTINCT user_id) AS users_with_integration
|
||||
FROM exploded
|
||||
WHERE integration IS NOT NULL AND integration <> ''
|
||||
GROUP BY integration
|
||||
ORDER BY users_with_integration DESC
|
||||
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
@@ -0,0 +1,145 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.users_activities
|
||||
-- Looker source alias: ds56 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user with lifetime activity summary.
|
||||
-- Joins login sessions with agent graphs, executions and
|
||||
-- node-level runs to give a full picture of how engaged
|
||||
-- each user is. Includes a convenience flag for 7-day
|
||||
-- activation (did the user return at least 7 days after
|
||||
-- their first login?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login/session records
|
||||
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||
-- platform.AgentGraphExecution — Agent run history
|
||||
-- platform.AgentNodeExecution — Individual block execution history
|
||||
--
|
||||
-- PERFORMANCE NOTE
|
||||
-- Each CTE aggregates its own table independently by userId.
|
||||
-- This avoids the fan-out that occurs when driving every join
|
||||
-- from user_logins across the two largest tables
|
||||
-- (AgentGraphExecution and AgentNodeExecution).
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- user_id TEXT Supabase user UUID
|
||||
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||
-- node_execution_count BIGINT Total node executions across all runs
|
||||
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Users who ran at least one agent and returned after 7 days
|
||||
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||
--
|
||||
-- -- Top 10 most active users by agent runs
|
||||
-- SELECT user_id, agent_runs, node_execution_count
|
||||
-- FROM analytics.users_activities
|
||||
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||
--
|
||||
-- -- 7-day activation rate
|
||||
-- SELECT
|
||||
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||
-- AS activation_rate
|
||||
-- FROM analytics.users_activities;
|
||||
-- =============================================================
|
||||
|
||||
WITH user_logins AS (
|
||||
SELECT
|
||||
user_id::text AS user_id,
|
||||
MIN(created_at) AS first_login_time,
|
||||
MAX(created_at) AS last_login_time,
|
||||
GREATEST(
|
||||
MAX(refreshed_at)::timestamptz,
|
||||
MAX(created_at)::timestamptz
|
||||
) AS last_visit_time
|
||||
FROM auth.sessions
|
||||
GROUP BY user_id
|
||||
),
|
||||
user_agents AS (
|
||||
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MAX("updatedAt") AS last_agent_save_time,
|
||||
COUNT(DISTINCT "id") AS agent_count
|
||||
FROM platform."AgentGraph"
|
||||
WHERE "isActive"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_graph_runs AS (
|
||||
-- Aggregate AgentGraphExecution directly by userId
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MIN("createdAt") AS first_agent_run_time,
|
||||
MAX("createdAt") AS last_agent_run_time,
|
||||
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||
COUNT("id") AS agent_runs
|
||||
FROM platform."AgentGraphExecution"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_node_runs AS (
|
||||
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||
-- single join to AgentGraphExecution instead of fanning out from
|
||||
-- user_logins through both large tables.
|
||||
SELECT
|
||||
g."userId"::text AS user_id,
|
||||
COUNT(*) AS node_execution_count,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||
FROM platform."AgentNodeExecution" n
|
||||
JOIN platform."AgentGraphExecution" g
|
||||
ON g."id" = n."agentGraphExecutionId"
|
||||
GROUP BY g."userId"
|
||||
)
|
||||
SELECT
|
||||
ul.user_id,
|
||||
ul.first_login_time,
|
||||
ul.last_login_time,
|
||||
ul.last_visit_time,
|
||||
ua.last_agent_save_time,
|
||||
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||
gr.first_agent_run_time,
|
||||
gr.last_agent_run_time,
|
||||
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||
CASE
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||
ELSE NULL
|
||||
END AS is_active_after_7d,
|
||||
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||
FROM user_logins ul
|
||||
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||
54
autogpt_platform/autogpt_libs/poetry.lock
generated
54
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
||||
optional = false
|
||||
python-versions = "<3.11,>=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||
@@ -541,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||
@@ -2181,14 +2181,14 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "7.0.0"
|
||||
version = "7.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||
{file = "pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678"},
|
||||
{file = "pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2342,30 +2342,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.0"
|
||||
version = "0.15.7"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.15.0-py3-none-linux_armv6l.whl", hash = "sha256:aac4ebaa612a82b23d45964586f24ae9bc23ca101919f5590bdb368d74ad5455"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dcd4be7cc75cfbbca24a98d04d0b9b36a270d0833241f776b788d59f4142b14d"},
|
||||
{file = "ruff-0.15.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d747e3319b2bce179c7c1eaad3d884dc0a199b5f4d5187620530adf9105268ce"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:650bd9c56ae03102c51a5e4b554d74d825ff3abe4db22b90fd32d816c2e90621"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a6664b7eac559e3048223a2da77769c2f92b43a6dfd4720cef42654299a599c9"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f811f97b0f092b35320d1556f3353bf238763420ade5d9e62ebd2b73f2ff179"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:761ec0a66680fab6454236635a39abaf14198818c8cdf691e036f4bc0f406b2d"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:940f11c2604d317e797b289f4f9f3fa5555ffe4fb574b55ed006c3d9b6f0eb78"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbca3d40558789126da91d7ef9a7c87772ee107033db7191edefa34e2c7f1b4"},
|
||||
{file = "ruff-0.15.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9a121a96db1d75fa3eb39c4539e607f628920dd72ff1f7c5ee4f1b768ac62d6e"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5298d518e493061f2eabd4abd067c7e4fb89e2f63291c94332e35631c07c3662"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:afb6e603d6375ff0d6b0cee563fa21ab570fd15e65c852cb24922cef25050cf1"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:77e515f6b15f828b94dc17d2b4ace334c9ddb7d9468c54b2f9ed2b9c1593ef16"},
|
||||
{file = "ruff-0.15.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6f6e80850a01eb13b3e42ee0ebdf6e4497151b48c35051aab51c101266d187a3"},
|
||||
{file = "ruff-0.15.0-py3-none-win32.whl", hash = "sha256:238a717ef803e501b6d51e0bdd0d2c6e8513fe9eec14002445134d3907cd46c3"},
|
||||
{file = "ruff-0.15.0-py3-none-win_amd64.whl", hash = "sha256:dd5e4d3301dc01de614da3cdffc33d4b1b96fb89e45721f1598e5532ccf78b18"},
|
||||
{file = "ruff-0.15.0-py3-none-win_arm64.whl", hash = "sha256:c480d632cc0ca3f0727acac8b7d053542d9e114a462a145d0b00e7cd658c515a"},
|
||||
{file = "ruff-0.15.0.tar.gz", hash = "sha256:6bdea47cdbea30d40f8f8d7d69c0854ba7c15420ec75a26f463290949d7f7e9a"},
|
||||
{file = "ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477"},
|
||||
{file = "ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5"},
|
||||
{file = "ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d"},
|
||||
{file = "ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580"},
|
||||
{file = "ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de"},
|
||||
{file = "ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1"},
|
||||
{file = "ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2"},
|
||||
{file = "ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2564,7 +2564,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||
@@ -2912,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
content-hash = "e0936a065565550afed18f6298b7e04e814b44100def7049f1a0d68662624a39"
|
||||
|
||||
@@ -26,8 +26,8 @@ pyright = "^1.1.408"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-cov = "^7.1.0"
|
||||
ruff = "^0.15.7"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -178,6 +178,7 @@ SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
AGENTMAIL_API_KEY=
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
|
||||
227
autogpt_platform/backend/AGENTS.md
Normal file
227
autogpt_platform/backend/AGENTS.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Backend
|
||||
|
||||
This file provides guidance to coding agents when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **Absolute imports** — use `from backend.module import ...` for cross-package imports. Single-dot relative (`from .sibling import ...`) is acceptable for sibling modules within the same package (e.g., blocks). Avoid double-dot relative imports (`from ..parent import ...`) — use the absolute path instead
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
### Test-Driven Development (TDD)
|
||||
|
||||
When fixing a bug or adding a feature, write the test **before** the implementation:
|
||||
|
||||
```python
|
||||
# 1. Write a failing test marked xfail
|
||||
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
|
||||
# 2. Run it — confirm it fails (XFAIL)
|
||||
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
|
||||
|
||||
# 3. Implement the fix
|
||||
|
||||
# 4. Remove xfail, run again — confirm it passes
|
||||
def test_widget_handles_empty_input():
|
||||
result = widget.process("")
|
||||
assert result == Widget.EMPTY_RESULT
|
||||
```
|
||||
|
||||
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -1,170 +1 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@AGENTS.md
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
@@ -121,19 +121,21 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
# Install agent-browser (Copilot browser tool) using the system chromium package.
|
||||
# Chrome for Testing (the binary agent-browser downloads via `agent-browser install`)
|
||||
# has no ARM64 builds, so we use the distro-packaged chromium instead — verified to
|
||||
# work with agent-browser via Docker tests on arm64; amd64 is validated in CI.
|
||||
# Note: system chromium tracks the Debian package schedule rather than a pinned
|
||||
# Chrome for Testing release. If agent-browser requires a specific Chrome version,
|
||||
# verify compatibility against the chromium package version in the base image.
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends chromium fonts-liberation \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
ENV AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -18,14 +18,22 @@ from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.integrations.models import get_all_provider_names
|
||||
from backend.api.features.integrations.router import (
|
||||
CredentialsMetaResponse,
|
||||
to_meta_response,
|
||||
)
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
@@ -91,18 +99,6 @@ class OAuthCompleteResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CredentialSummary(BaseModel):
|
||||
"""Summary of a credential without sensitive data."""
|
||||
|
||||
id: str
|
||||
provider: str
|
||||
type: CredentialsType
|
||||
title: Optional[str] = None
|
||||
scopes: Optional[list[str]] = None
|
||||
username: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Information about an integration provider."""
|
||||
|
||||
@@ -473,12 +469,12 @@ async def complete_oauth(
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
|
||||
@integrations_router.get("/credentials", response_model=list[CredentialsMetaResponse])
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List all credentials for the authenticated user.
|
||||
|
||||
@@ -486,28 +482,19 @@ async def list_credentials(
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
"/{provider}/credentials", response_model=list[CredentialSummary]
|
||||
"/{provider}/credentials", response_model=list[CredentialsMetaResponse]
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: Annotated[str, Path(title="The provider to list credentials for")],
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> list[CredentialSummary]:
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
"""
|
||||
List credentials for a specific provider.
|
||||
"""
|
||||
@@ -515,16 +502,7 @@ async def list_credentials_by_provider(
|
||||
auth.user_id, provider
|
||||
)
|
||||
return [
|
||||
CredentialSummary(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -597,11 +575,11 @@ async def create_credential(
|
||||
# Store credentials
|
||||
try:
|
||||
await creds_manager.create(auth.user_id, credentials)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
|
||||
logger.info(f"Created {request.type} credentials for provider {provider}")
|
||||
@@ -639,15 +617,23 @@ async def delete_credential(
|
||||
use the main API's delete endpoint which handles webhook cleanup and
|
||||
token revocation.
|
||||
"""
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(auth.user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
|
||||
await creds_manager.delete(auth.user_id, cred_id)
|
||||
|
||||
@@ -72,7 +72,7 @@ class RunAgentRequest(BaseModel):
|
||||
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
return ChatSession.new(user_id, dry_run=False)
|
||||
|
||||
|
||||
@tools_router.post(
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Admin endpoints for checking and resetting user CoPilot rate limit usage."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.rate_limit import (
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
reset_user_usage,
|
||||
)
|
||||
from backend.data.user import get_user_by_email, get_user_email_by_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["copilot", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitResponse(BaseModel):
|
||||
user_id: str
|
||||
user_email: Optional[str] = None
|
||||
daily_token_limit: int
|
||||
weekly_token_limit: int
|
||||
daily_tokens_used: int
|
||||
weekly_tokens_used: int
|
||||
|
||||
|
||||
async def _resolve_user_id(
|
||||
user_id: Optional[str], email: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Resolve a user_id and email from the provided parameters.
|
||||
|
||||
Returns (user_id, email). Accepts either user_id or email; at least one
|
||||
must be provided. When both are provided, ``email`` takes precedence.
|
||||
"""
|
||||
if email:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No user found with the provided email."
|
||||
)
|
||||
return user.id, email
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either user_id or email query parameter is required.",
|
||||
)
|
||||
|
||||
# We have a user_id; try to look up their email for display purposes.
|
||||
# This is non-critical -- a failure should not block the response.
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
return user_id, resolved_email
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rate_limit",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Get User Rate Limit",
|
||||
)
|
||||
async def get_user_rate_limit(
|
||||
user_id: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Get a user's current usage and effective rate limits. Admin-only.
|
||||
|
||||
Accepts either ``user_id`` or ``email`` as a query parameter.
|
||||
When ``email`` is provided the user is looked up by email first.
|
||||
"""
|
||||
resolved_id, resolved_email = await _resolve_user_id(user_id, email)
|
||||
|
||||
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
resolved_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=resolved_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rate_limit/reset",
|
||||
response_model=UserRateLimitResponse,
|
||||
summary="Reset User Rate Limit Usage",
|
||||
)
|
||||
async def reset_user_rate_limit(
|
||||
user_id: str = Body(embed=True),
|
||||
reset_weekly: bool = Body(False, embed=True),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> UserRateLimitResponse:
|
||||
"""Reset a user's daily usage counter (and optionally weekly). Admin-only."""
|
||||
logger.info(
|
||||
"Admin %s resetting rate limit for user %s (reset_weekly=%s)",
|
||||
admin_user_id,
|
||||
user_id,
|
||||
reset_weekly,
|
||||
)
|
||||
|
||||
try:
|
||||
await reset_user_usage(user_id, reset_weekly=reset_weekly)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to reset user usage")
|
||||
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
|
||||
|
||||
try:
|
||||
resolved_email = await get_user_email_by_id(user_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve email for user %s", user_id, exc_info=True)
|
||||
resolved_email = None
|
||||
|
||||
return UserRateLimitResponse(
|
||||
user_id=user_id,
|
||||
user_email=resolved_email,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
daily_tokens_used=usage.daily.used,
|
||||
weekly_tokens_used=usage.weekly.used,
|
||||
)
|
||||
@@ -0,0 +1,263 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
from .rate_limit_admin_routes import router as rate_limit_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(rate_limit_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
_MOCK_MODULE = "backend.api.features.admin.rate_limit_admin_routes"
|
||||
|
||||
_TARGET_EMAIL = "target@example.com"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all tests in this module"""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_usage_status(
|
||||
daily_used: int = 500_000, weekly_used: int = 3_000_000
|
||||
) -> CoPilotUsageStatus:
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used, limit=2_500_000, resets_at=now + timedelta(hours=6)
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used, limit=12_500_000, resets_at=now + timedelta(days=3)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _patch_rate_limit_deps(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
daily_used: int = 500_000,
|
||||
weekly_used: int = 3_000_000,
|
||||
):
|
||||
"""Patch the common rate-limit + user-lookup dependencies."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(daily_used=daily_used, weekly_used=weekly_used),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_TARGET_EMAIL,
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test getting rate limit and usage for a user."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
assert data["weekly_token_limit"] == 12_500_000
|
||||
assert data["daily_tokens_used"] == 500_000
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"get_rate_limit",
|
||||
)
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test looking up rate limits via email instead of user_id."""
|
||||
_patch_rate_limit_deps(mocker, target_user_id)
|
||||
|
||||
mock_user = SimpleNamespace(id=target_user_id, email=_TARGET_EMAIL)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": _TARGET_EMAIL})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] == _TARGET_EMAIL
|
||||
assert data["daily_token_limit"] == 2_500_000
|
||||
|
||||
|
||||
def test_get_rate_limit_by_email_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test that looking up a non-existent email returns 404."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_by_email",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"email": "nobody@example.com"})
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_rate_limit_no_params() -> None:
|
||||
"""Test that omitting both user_id and email returns 400."""
|
||||
response = client.get("/admin/rate_limit")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_only(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting only daily usage (default behaviour)."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=3_000_000)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
# Weekly is untouched
|
||||
assert data["weekly_tokens_used"] == 3_000_000
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_only",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test resetting both daily and weekly usage."""
|
||||
mock_reset = mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
_patch_rate_limit_deps(mocker, target_user_id, daily_used=0, weekly_used=0)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id, "reset_weekly": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily_tokens_used"] == 0
|
||||
assert data["weekly_tokens_used"] == 0
|
||||
|
||||
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n",
|
||||
"reset_user_usage_daily_and_weekly",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_user_usage_redis_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that Redis failure on reset returns 500."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.reset_user_usage",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Redis connection refused"),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": target_user_id},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
def test_get_rate_limit_email_lookup_failure(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
target_user_id: str,
|
||||
) -> None:
|
||||
"""Test that failing to resolve a user email degrades gracefully."""
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_global_rate_limits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(2_500_000, 12_500_000),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_mock_usage_status(),
|
||||
)
|
||||
mocker.patch(
|
||||
f"{_MOCK_MODULE}.get_user_email_by_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("DB connection lost"),
|
||||
)
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": target_user_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == target_user_id
|
||||
assert data["user_email"] is None
|
||||
|
||||
|
||||
def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
|
||||
"""Test that rate limit admin endpoints require admin role."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
|
||||
response = client.get("/admin/rate_limit", params={"user_id": "test"})
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.post(
|
||||
"/admin/rate_limit/reset",
|
||||
json={"user_id": "test"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -7,6 +7,8 @@ import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
@@ -132,3 +134,40 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/submissions/{store_listing_version_id}/preview",
|
||||
summary="Admin Preview Submission Listing",
|
||||
)
|
||||
async def admin_preview_submission(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""
|
||||
Preview a marketplace submission as it would appear on the listing page.
|
||||
Bypasses the APPROVED-only StoreAgent view so admins can preview pending
|
||||
submissions before approving.
|
||||
"""
|
||||
return await store_db.get_store_agent_details_as_admin(store_listing_version_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/add-to-library",
|
||||
summary="Admin Add Pending Agent to Library",
|
||||
status_code=201,
|
||||
)
|
||||
async def admin_add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add a pending marketplace agent to the admin's library for review.
|
||||
Uses admin-level access to bypass marketplace APPROVED-only checks.
|
||||
|
||||
The builder can load the graph because get_graph() checks library
|
||||
membership as a fallback: "you added it, you keep it."
|
||||
"""
|
||||
return await library_db.add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
"""Tests for admin store routes and the bypass logic they depend on.
|
||||
|
||||
Tests are organized by what they protect:
|
||||
- SECRT-2162: get_graph_as_admin bypasses ownership/marketplace checks
|
||||
- SECRT-2167 security: admin endpoints reject non-admin users
|
||||
- SECRT-2167 bypass: preview queries StoreListingVersion (not StoreAgent view),
|
||||
and add-to-library uses get_graph_as_admin (not get_graph)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.graph import get_graph_as_admin
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .store_admin_routes import router as store_admin_router
|
||||
|
||||
# Shared constants
|
||||
ADMIN_USER_ID = "admin-user-id"
|
||||
CREATOR_USER_ID = "other-creator-id"
|
||||
GRAPH_ID = "test-graph-id"
|
||||
GRAPH_VERSION = 3
|
||||
SLV_ID = "test-store-listing-version-id"
|
||||
|
||||
|
||||
def _make_mock_graph(user_id: str = CREATOR_USER_ID) -> MagicMock:
|
||||
graph = MagicMock()
|
||||
graph.userId = user_id
|
||||
graph.id = GRAPH_ID
|
||||
graph.version = GRAPH_VERSION
|
||||
graph.Nodes = []
|
||||
return graph
|
||||
|
||||
|
||||
# ---- SECRT-2162: get_graph_as_admin bypasses ownership checks ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_access_pending_agent_not_owned() -> None:
|
||||
"""get_graph_as_admin must return a graph even when the admin doesn't own
|
||||
it and it's not APPROVED in the marketplace."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=False,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_download_pending_agent_with_subagents() -> None:
|
||||
"""get_graph_as_admin with for_export=True must call get_sub_graphs
|
||||
and pass sub_graphs to GraphModel.from_db."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_sub_graph = MagicMock(name="SubGraph")
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.data.graph.get_sub_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[mock_sub_graph],
|
||||
) as mock_get_sub,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.find_first = AsyncMock(return_value=mock_graph)
|
||||
|
||||
result = await get_graph_as_admin(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_get_sub.assert_awaited_once_with(mock_graph)
|
||||
mock_from_db.assert_called_once_with(
|
||||
graph=mock_graph,
|
||||
sub_graphs=[mock_sub_graph],
|
||||
for_export=True,
|
||||
)
|
||||
|
||||
|
||||
# ---- SECRT-2167 security: admin endpoints reject non-admin users ---- #
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(store_admin_router)
|
||||
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def _not_found_handler(
|
||||
request: fastapi.Request, exc: NotFoundError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
"""Setup admin auth overrides for all route tests in this module."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_preview_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the preview endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_add_to_library_requires_admin(mock_jwt_user) -> None:
|
||||
"""Non-admin users must get 403 on the add-to-library endpoint."""
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
response = client.post(f"/admin/submissions/{SLV_ID}/add-to-library")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_preview_nonexistent_submission(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Preview of a nonexistent submission returns 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.store_admin_routes.store_db"
|
||||
".get_store_agent_details_as_admin",
|
||||
side_effect=NotFoundError("not found"),
|
||||
)
|
||||
response = client.get(f"/admin/submissions/{SLV_ID}/preview")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- SECRT-2167 bypass: verify the right data sources are used ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_queries_store_listing_version_not_store_agent() -> None:
|
||||
"""get_store_agent_details_as_admin must query StoreListingVersion
|
||||
directly (not the APPROVED-only StoreAgent view). This is THE test that
|
||||
prevents the bypass from being accidentally reverted."""
|
||||
from backend.api.features.store.db import get_store_agent_details_as_admin
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.id = SLV_ID
|
||||
mock_slv.name = "Test Agent"
|
||||
mock_slv.subHeading = "Short desc"
|
||||
mock_slv.description = "Long desc"
|
||||
mock_slv.videoUrl = None
|
||||
mock_slv.agentOutputDemoUrl = None
|
||||
mock_slv.imageUrls = ["https://example.com/img.png"]
|
||||
mock_slv.instructions = None
|
||||
mock_slv.categories = ["productivity"]
|
||||
mock_slv.version = 1
|
||||
mock_slv.agentGraphId = GRAPH_ID
|
||||
mock_slv.agentGraphVersion = GRAPH_VERSION
|
||||
mock_slv.updatedAt = datetime(2026, 3, 24, tzinfo=timezone.utc)
|
||||
mock_slv.recommendedScheduleCron = "0 9 * * *"
|
||||
|
||||
mock_listing = MagicMock()
|
||||
mock_listing.id = "listing-id"
|
||||
mock_listing.slug = "test-agent"
|
||||
mock_listing.activeVersionId = SLV_ID
|
||||
mock_listing.hasApprovedVersion = False
|
||||
mock_listing.CreatorProfile = MagicMock(username="creator", avatarUrl="")
|
||||
mock_slv.StoreListing = mock_listing
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models" ".StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch(
|
||||
"backend.api.features.store.db.prisma.models.StoreAgent.prisma",
|
||||
) as mock_store_agent_prisma,
|
||||
):
|
||||
mock_slv_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await get_store_agent_details_as_admin(SLV_ID)
|
||||
|
||||
# Verify it queried StoreListingVersion (not the APPROVED-only StoreAgent)
|
||||
mock_slv_prisma.return_value.find_unique.assert_awaited_once()
|
||||
await_args = mock_slv_prisma.return_value.find_unique.await_args
|
||||
assert await_args is not None
|
||||
assert await_args.kwargs["where"] == {"id": SLV_ID}
|
||||
|
||||
# Verify the APPROVED-only StoreAgent view was NOT touched
|
||||
mock_store_agent_prisma.assert_not_called()
|
||||
|
||||
# Verify the result has the right data
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_image == ["https://example.com/img.png"]
|
||||
assert result.has_approved_version is False
|
||||
assert result.runs == 0
|
||||
assert result.rating == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_admin_uses_get_graph_as_admin() -> None:
|
||||
"""resolve_graph_for_library(admin=True) must call get_graph_as_admin,
|
||||
not get_graph. This is THE test that prevents the add-to-library bypass
|
||||
from being accidentally reverted."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, ADMIN_USER_ID, admin=True)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_admin.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id=ADMIN_USER_ID
|
||||
)
|
||||
mock_regular.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_graph_regular_uses_get_graph() -> None:
|
||||
"""resolve_graph_for_library(admin=False) must call get_graph,
|
||||
not get_graph_as_admin. Ensures the non-admin path is preserved."""
|
||||
from backend.api.features.library._add_to_library import resolve_graph_for_library
|
||||
|
||||
mock_slv = MagicMock()
|
||||
mock_slv.AgentGraph = MagicMock(id=GRAPH_ID, version=GRAPH_VERSION)
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models"
|
||||
".StoreListingVersion.prisma",
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
".get_graph_as_admin",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_admin,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.graph_db.get_graph",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_graph_model,
|
||||
) as mock_regular,
|
||||
):
|
||||
mock_prisma.return_value.find_unique = AsyncMock(return_value=mock_slv)
|
||||
|
||||
result = await resolve_graph_for_library(SLV_ID, "regular-user-id", admin=False)
|
||||
|
||||
assert result is mock_graph_model
|
||||
mock_regular.assert_awaited_once_with(
|
||||
graph_id=GRAPH_ID, version=GRAPH_VERSION, user_id="regular-user-id"
|
||||
)
|
||||
mock_admin.assert_not_awaited()
|
||||
|
||||
|
||||
# ---- Library membership grants graph access (product decision) ---- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_member_can_view_pending_agent_in_builder() -> None:
|
||||
"""After adding a pending agent to their library, the user should be
|
||||
able to load the graph in the builder via get_graph()."""
|
||||
mock_graph = _make_mock_graph()
|
||||
mock_graph_model = MagicMock(name="GraphModel")
|
||||
mock_library_agent = MagicMock()
|
||||
mock_library_agent.AgentGraph = mock_graph
|
||||
|
||||
with (
|
||||
patch("backend.data.graph.AgentGraph.prisma") as mock_ag_prisma,
|
||||
patch(
|
||||
"backend.data.graph.StoreListingVersion.prisma",
|
||||
) as mock_slv_prisma,
|
||||
patch("backend.data.graph.LibraryAgent.prisma") as mock_lib_prisma,
|
||||
patch(
|
||||
"backend.data.graph.GraphModel.from_db",
|
||||
return_value=mock_graph_model,
|
||||
),
|
||||
):
|
||||
mock_ag_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_slv_prisma.return_value.find_first = AsyncMock(return_value=None)
|
||||
mock_lib_prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_library_agent
|
||||
)
|
||||
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
result = await get_graph(
|
||||
graph_id=GRAPH_ID,
|
||||
version=GRAPH_VERSION,
|
||||
user_id=ADMIN_USER_ID,
|
||||
)
|
||||
|
||||
assert result is mock_graph_model, "Library membership should grant graph access"
|
||||
@@ -4,14 +4,12 @@ from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -24,6 +22,7 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
@@ -271,7 +270,7 @@ async def _build_cached_search_results(
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
block_results, block_total, integration_total = await _text_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
@@ -383,117 +382,75 @@ def _collect_block_results(
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
async def _text_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
Search blocks using in-memory text matching over the block registry.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
All blocks are already loaded in memory, so this is fast and reliable
|
||||
regardless of whether OpenAI embeddings are available.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
return results, 0, 0
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
all_results, _, _ = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
for item in all_results:
|
||||
block_info = item.item
|
||||
assert isinstance(block_info, BlockInfo)
|
||||
name = split_camelcase(block_info.name).lower()
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Build rich description including input field descriptions,
|
||||
# matching the searchable text that the embedding pipeline uses
|
||||
desc_parts = [block_info.description or ""]
|
||||
block_cls = all_blocks.get(block_info.id)
|
||||
if block_cls is not None:
|
||||
block: AnyBlockSchema = block_cls()
|
||||
desc_parts += [
|
||||
f"{f}: {info.description}"
|
||||
for f, info in block.input_schema.model_fields.items()
|
||||
if info.description
|
||||
]
|
||||
description = " ".join(desc_parts).lower()
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
if block_cls is not None and _matches_llm_model(
|
||||
block_cls().input_schema, normalized_query
|
||||
):
|
||||
score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=item.filter_type,
|
||||
score=score + BLOCK_SCORE_BOOST,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
block_count = sum(1 for r in results if r.filter_type == "blocks")
|
||||
integration_count = sum(1 for r in results if r.filter_type == "integrations")
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -20,6 +20,7 @@ from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
ChatSessionMetadata,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
@@ -27,6 +28,18 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
acquire_reset_lock,
|
||||
check_rate_limit,
|
||||
get_daily_reset_count,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
increment_daily_reset_count,
|
||||
release_reset_lock,
|
||||
reset_daily_usage,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -53,9 +66,16 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import InsufficientBalanceError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -63,8 +83,6 @@ _UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -95,12 +113,25 @@ class StreamChatRequest(BaseModel):
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
``dry_run`` is a **top-level** field — do not nest it inside ``metadata``.
|
||||
Extra/unknown fields are rejected (422) to prevent silent mis-use.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
dry_run: bool = False
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
@@ -119,6 +150,9 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
metadata: ChatSessionMetadata = ChatSessionMetadata()
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -206,7 +240,7 @@ async def list_sessions(
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
"Failed to fetch processing status from Redis; defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
@@ -228,7 +262,8 @@ async def list_sessions(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
request: CreateSessionRequest | None = None,
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -237,22 +272,28 @@ async def create_session(
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
request: Optional request body. When provided, ``dry_run=True``
|
||||
forces run_block and run_agent calls to use dry-run simulation.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
dry_run = request.dry_run if request else False
|
||||
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
f"{', dry_run=True' if dry_run else ''}"
|
||||
)
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id,
|
||||
metadata=session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -347,7 +388,7 @@ async def update_session_title_route(
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
@@ -388,6 +429,10 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -395,6 +440,202 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
metadata=session.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/usage",
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
Global defaults sourced from LaunchDarkly (falling back to config).
|
||||
"""
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitResetResponse(BaseModel):
|
||||
"""Response from resetting the daily rate limit."""
|
||||
|
||||
success: bool
|
||||
credits_charged: int = Field(description="Credits charged (in cents)")
|
||||
remaining_balance: int = Field(description="Credit balance after charge (in cents)")
|
||||
usage: CoPilotUsageStatus = Field(description="Updated usage status after reset")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/reset",
|
||||
status_code=200,
|
||||
responses={
|
||||
400: {
|
||||
"description": "Bad Request (feature disabled or daily limit not reached)"
|
||||
},
|
||||
402: {"description": "Payment Required (insufficient credits)"},
|
||||
429: {
|
||||
"description": "Too Many Requests (max daily resets exceeded or reset in progress)"
|
||||
},
|
||||
503: {
|
||||
"description": "Service Unavailable (Redis reset failed; credits refunded or support needed)"
|
||||
},
|
||||
},
|
||||
)
|
||||
async def reset_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> RateLimitResetResponse:
|
||||
"""Reset the daily CoPilot rate limit by spending credits.
|
||||
|
||||
Allows users who have hit their daily token limit to spend credits
|
||||
to reset their daily usage counter and continue working.
|
||||
Returns 400 if the feature is disabled or the user is not over the limit.
|
||||
Returns 402 if the user has insufficient credits.
|
||||
"""
|
||||
cost = config.rate_limit_reset_cost
|
||||
if cost <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available.",
|
||||
)
|
||||
|
||||
if not settings.config.enable_credit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Rate limit reset is not available (credit system is disabled).",
|
||||
)
|
||||
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
|
||||
if daily_limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No daily limit is configured — nothing to reset.",
|
||||
)
|
||||
|
||||
# Check max daily resets. get_daily_reset_count returns None when Redis
|
||||
# is unavailable; reject the reset in that case to prevent unlimited
|
||||
# free resets when the counter store is down.
|
||||
reset_count = await get_daily_reset_count(user_id)
|
||||
if reset_count is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Unable to verify reset eligibility — please try again later.",
|
||||
)
|
||||
if config.max_daily_resets > 0 and reset_count >= config.max_daily_resets:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"You've used all {config.max_daily_resets} resets for today.",
|
||||
)
|
||||
|
||||
# Acquire a per-user lock to prevent TOCTOU races (concurrent resets).
|
||||
if not await acquire_reset_lock(user_id):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="A reset is already in progress. Please try again.",
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the user is actually at or over their daily limit.
|
||||
usage_status = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
if daily_limit > 0 and usage_status.daily.used < daily_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="You have not reached your daily limit yet.",
|
||||
)
|
||||
|
||||
# If the weekly limit is also exhausted, resetting the daily counter
|
||||
# won't help — the user would still be blocked by the weekly limit.
|
||||
if weekly_limit > 0 and usage_status.weekly.used >= weekly_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Your weekly limit is also reached. Resetting the daily limit won't help.",
|
||||
)
|
||||
|
||||
# Charge credits.
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
try:
|
||||
remaining = await credit_model.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
reason="CoPilot daily rate limit reset",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient credits to reset your rate limit.",
|
||||
) from e
|
||||
|
||||
# Reset daily usage in Redis. If this fails, refund the credits
|
||||
# so the user is not charged for a service they did not receive.
|
||||
if not await reset_daily_usage(user_id, daily_token_limit=daily_limit):
|
||||
# Compensate: refund the charged credits.
|
||||
refunded = False
|
||||
try:
|
||||
await credit_model.top_up_credits(user_id, cost)
|
||||
refunded = True
|
||||
logger.warning(
|
||||
"Refunded %d credits to user %s after Redis reset failure",
|
||||
cost,
|
||||
user_id[:8],
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"CRITICAL: Failed to refund %d credits to user %s "
|
||||
"after Redis reset failure — manual intervention required",
|
||||
cost,
|
||||
user_id[:8],
|
||||
exc_info=True,
|
||||
)
|
||||
if refunded:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed — please try again later. "
|
||||
"Your credits have not been charged.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limit reset failed and the automatic refund "
|
||||
"also failed. Please contact support for assistance.",
|
||||
)
|
||||
|
||||
# Track the reset count for daily cap enforcement.
|
||||
await increment_daily_reset_count(user_id)
|
||||
finally:
|
||||
await release_reset_lock(user_id)
|
||||
|
||||
# Return updated usage status.
|
||||
updated_usage = await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
rate_limit_reset_cost=config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
return RateLimitResetResponse(
|
||||
success=True,
|
||||
credits_charged=cost,
|
||||
remaining_balance=remaining,
|
||||
usage=updated_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -404,7 +645,7 @@ async def get_session(
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
@@ -449,7 +690,7 @@ async def cancel_session_task(
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
@@ -466,7 +707,7 @@ async def stream_chat_post(
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
user_id: Authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
@@ -475,9 +716,7 @@ async def stream_chat_post(
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
@@ -495,6 +734,22 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
# Global defaults sourced from LaunchDarkly, falling back to config.
|
||||
if user_id:
|
||||
try:
|
||||
daily_limit, weekly_limit = await get_global_rate_limits(
|
||||
user_id, config.daily_token_limit, config.weekly_token_limit
|
||||
)
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=daily_limit,
|
||||
weekly_token_limit=weekly_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -729,7 +984,7 @@ async def stream_chat_post(
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Resume an active stream for a session.
|
||||
@@ -853,6 +1108,47 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedTheme(BaseModel):
|
||||
"""A themed group of suggested prompts."""
|
||||
|
||||
name: str
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts grouped by theme."""
|
||||
|
||||
themes: list[SuggestedTheme]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts grouped by theme.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns empty themes list if no custom
|
||||
prompts are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None or not understanding.suggested_prompts:
|
||||
return SuggestedPromptsResponse(themes=[])
|
||||
|
||||
themes = [
|
||||
SuggestedTheme(name=name, prompts=prompts)
|
||||
for name, prompts in understanding.suggested_prompts.items()
|
||||
]
|
||||
return SuggestedPromptsResponse(themes=themes)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -901,7 +1197,7 @@ async def health_check() -> dict:
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
session = await create_chat_session(health_check_user_id, dry_run=False)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -249,3 +250,279 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
# Ensure the rate-limit branch is entered by setting a non-zero limit.
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
resets_at = datetime.now(UTC) + timedelta(days=3)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("weekly", resets_at),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"].lower()
|
||||
assert "weekly" in detail
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded(
|
||||
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"]
|
||||
assert "2h" in detail
|
||||
assert "Resets in" in detail
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
rate_limit_reset_cost=500,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_rejects_unauthenticated_request() -> None:
|
||||
"""GET /usage should return 401 when no valid JWT is provided."""
|
||||
unauthenticated_app = fastapi.FastAPI()
|
||||
unauthenticated_app.include_router(chat_routes.router)
|
||||
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
|
||||
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_themes(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with themed prompts gets them back as themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {
|
||||
"Learn": ["L1", "L2"],
|
||||
"Create": ["C1"],
|
||||
}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "themes" in data
|
||||
themes_by_name = {t["name"]: t["prompts"] for t in data["themes"]}
|
||||
assert themes_by_name["Learn"] == ["L1", "L2"]
|
||||
assert themes_by_name["Create"] == ["C1"]
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty themes list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but empty prompts gets empty themes list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = {}
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"themes": []}
|
||||
|
||||
|
||||
# ─── Create session: dry_run contract ─────────────────────────────────
|
||||
|
||||
|
||||
def _mock_create_chat_session(mocker: pytest_mock.MockerFixture):
|
||||
"""Mock create_chat_session to return a fake session."""
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
async def _fake_create(user_id: str, *, dry_run: bool):
|
||||
return ChatSession.new(user_id, dry_run=dry_run)
|
||||
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.create_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=_fake_create,
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_dry_run_true(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"dry_run": true}`` sets metadata.dry_run to True."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions", json={"dry_run": True})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is True
|
||||
|
||||
|
||||
def test_create_session_dry_run_default_false(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Empty body defaults dry_run to False."""
|
||||
_mock_create_chat_session(mocker)
|
||||
|
||||
response = client.post("/sessions")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["metadata"]["dry_run"] is False
|
||||
|
||||
|
||||
def test_create_session_rejects_nested_metadata(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Sending ``{"metadata": {"dry_run": true}}`` must return 422, not silently
|
||||
default to ``dry_run=False``. This guards against the common mistake of
|
||||
nesting dry_run inside metadata instead of providing it at the top level."""
|
||||
response = client.post(
|
||||
"/sessions",
|
||||
json={"metadata": {"dry_run": True}},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Override session-scoped fixtures so unit tests run without the server."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup():
|
||||
yield
|
||||
@@ -34,16 +34,21 @@ from backend.data.model import (
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
is_sdk_default,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.credentials_store import (
|
||||
is_system_credential,
|
||||
provider_matches,
|
||||
)
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.managed_credentials import ensure_managed_credentials
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -109,6 +114,7 @@ class CredentialsMetaResponse(BaseModel):
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
is_managed: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -138,6 +144,19 @@ class CredentialsMetaResponse(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def to_meta_response(cred: Credentials) -> CredentialsMetaResponse:
|
||||
return CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
is_managed=cred.is_managed,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
provider: Annotated[
|
||||
@@ -204,34 +223,20 @@ async def callback(
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
@router.get("/credentials", summary="List Credentials")
|
||||
async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
# Fire-and-forget: provision missing managed credentials in the background.
|
||||
# The credential appears on the next page load; listing is never blocked.
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -242,19 +247,11 @@ async def list_credentials_by_provider(
|
||||
],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
asyncio.create_task(ensure_managed_credentials(user_id, creds_manager.store))
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
type=cred.type,
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
to_meta_response(cred) for cred in credentials if not is_sdk_default(cred.id)
|
||||
]
|
||||
|
||||
|
||||
@@ -267,18 +264,21 @@ async def get_credential(
|
||||
],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if credential.provider != provider:
|
||||
if not provider_matches(credential.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
return credential
|
||||
return to_meta_response(credential)
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
|
||||
@@ -288,16 +288,22 @@ async def create_credentials(
|
||||
ProviderName, Path(title="The provider to create credentials for")
|
||||
],
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
) -> CredentialsMetaResponse:
|
||||
if is_sdk_default(credentials.id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot create credentials with a reserved ID",
|
||||
)
|
||||
credentials.provider = provider
|
||||
try:
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to store credentials")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
detail="Failed to store credentials",
|
||||
)
|
||||
return credentials
|
||||
return to_meta_response(credentials)
|
||||
|
||||
|
||||
class CredentialsDeletionResponse(BaseModel):
|
||||
@@ -332,15 +338,29 @@ async def delete_credentials(
|
||||
bool, Query(title="Whether to proceed if any linked webhooks are still in use")
|
||||
] = False,
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
if is_sdk_default(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if is_system_credential(cred_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="System-managed credentials cannot be deleted",
|
||||
)
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
if creds.provider != provider:
|
||||
if not provider_matches(creds.provider, provider):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
detail="Credentials not found",
|
||||
)
|
||||
if creds.is_managed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="AutoGPT-managed credentials cannot be deleted",
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,570 @@
|
||||
"""Tests for credentials API security: no secret leakage, SDK defaults filtered."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.integrations.router import router
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "test-user-id"
|
||||
|
||||
|
||||
def _make_api_key_cred(cred_id: str = "cred-123", provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My API Key",
|
||||
api_key=SecretStr("sk-secret-key-value"),
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_cred(cred_id: str = "cred-456", provider: str = "github"):
|
||||
return OAuth2Credentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My OAuth",
|
||||
access_token=SecretStr("ghp_secret_token"),
|
||||
refresh_token=SecretStr("ghp_refresh_secret"),
|
||||
scopes=["repo", "user"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
|
||||
def _make_user_password_cred(cred_id: str = "cred-789", provider: str = "openai"):
|
||||
return UserPasswordCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="My Login",
|
||||
username=SecretStr("admin"),
|
||||
password=SecretStr("s3cret-pass"),
|
||||
)
|
||||
|
||||
|
||||
def _make_host_scoped_cred(cred_id: str = "cred-host", provider: str = "openai"):
|
||||
return HostScopedCredentials(
|
||||
id=cred_id,
|
||||
provider=provider,
|
||||
title="Host Cred",
|
||||
host="https://api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer top-secret")},
|
||||
)
|
||||
|
||||
|
||||
def _make_sdk_default_cred(provider: str = "openai"):
|
||||
return APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
title=f"{provider} (default)",
|
||||
api_key=SecretStr("sk-platform-secret-key"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestGetCredentialReturnsMetaOnly:
|
||||
"""GET /{provider}/credentials/{cred_id} must not return secrets."""
|
||||
|
||||
def test_api_key_credential_no_secret(self):
|
||||
cred = _make_api_key_cred()
|
||||
with (
|
||||
patch.object(router, "dependencies", []),
|
||||
patch("backend.api.features.integrations.router.creds_manager") as mock_mgr,
|
||||
):
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-123"
|
||||
assert data["provider"] == "openai"
|
||||
assert data["type"] == "api_key"
|
||||
assert "api_key" not in data
|
||||
assert "sk-secret-key-value" not in str(data)
|
||||
|
||||
def test_oauth2_credential_no_secret(self):
|
||||
cred = _make_oauth2_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-456")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-456"
|
||||
assert data["scopes"] == ["repo", "user"]
|
||||
assert data["username"] == "testuser"
|
||||
assert "access_token" not in data
|
||||
assert "refresh_token" not in data
|
||||
assert "ghp_" not in str(data)
|
||||
|
||||
def test_user_password_credential_no_secret(self):
|
||||
cred = _make_user_password_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-789")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-789"
|
||||
assert "password" not in data
|
||||
assert "username" not in data or data["username"] is None
|
||||
assert "s3cret-pass" not in str(data)
|
||||
assert "admin" not in str(data)
|
||||
|
||||
def test_host_scoped_credential_no_secret(self):
|
||||
cred = _make_host_scoped_cred()
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/openai/credentials/cred-host")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == "cred-host"
|
||||
assert data["host"] == "https://api.example.com"
|
||||
assert "headers" not in data
|
||||
assert "top-secret" not in str(data)
|
||||
|
||||
def test_get_credential_wrong_provider_returns_404(self):
|
||||
"""Provider mismatch should return generic 404, not leak credential existence."""
|
||||
cred = _make_api_key_cred(provider="openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock(return_value=cred)
|
||||
resp = client.get("/github/credentials/cred-123")
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.json()["detail"] == "Credentials not found"
|
||||
|
||||
def test_list_credentials_no_secrets(self):
|
||||
"""List endpoint must not leak secrets in any credential."""
|
||||
creds = [_make_api_key_cred(), _make_oauth2_cred()]
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=creds)
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
raw = str(resp.json())
|
||||
assert "sk-secret-key-value" not in raw
|
||||
assert "ghp_secret_token" not in raw
|
||||
assert "ghp_refresh_secret" not in raw
|
||||
|
||||
|
||||
class TestSdkDefaultCredentialsNotAccessible:
|
||||
"""SDK default credentials (ID ending in '-default') must be hidden."""
|
||||
|
||||
def test_get_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.get = AsyncMock()
|
||||
resp = client.get("/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.get.assert_not_called()
|
||||
|
||||
def test_list_credentials_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[user_cred, sdk_cred])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_list_by_provider_excludes_sdk_defaults(self):
|
||||
user_cred = _make_api_key_cred()
|
||||
sdk_cred = _make_sdk_default_cred("openai")
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[user_cred, sdk_cred]
|
||||
)
|
||||
resp = client.get("/openai/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
ids = [c["id"] for c in data]
|
||||
assert "cred-123" in ids
|
||||
assert "openai-default" not in ids
|
||||
|
||||
def test_delete_sdk_default_returns_404(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock()
|
||||
resp = client.request("DELETE", "/openai/credentials/openai-default")
|
||||
|
||||
assert resp.status_code == 404
|
||||
mock_mgr.store.get_creds_by_id.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateCredentialNoSecretInResponse:
|
||||
"""POST /{provider}/credentials must not return secrets."""
|
||||
|
||||
def test_create_api_key_no_secret_in_response(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "new-cred",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "New Key",
|
||||
"api_key": "sk-newsecret",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["id"] == "new-cred"
|
||||
assert "api_key" not in data
|
||||
assert "sk-newsecret" not in str(data)
|
||||
|
||||
def test_create_with_sdk_default_id_rejected(self):
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.create = AsyncMock()
|
||||
resp = client.post(
|
||||
"/openai/credentials",
|
||||
json={
|
||||
"id": "openai-default",
|
||||
"provider": "openai",
|
||||
"type": "api_key",
|
||||
"title": "Sneaky",
|
||||
"api_key": "sk-evil",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 403
|
||||
mock_mgr.create.assert_not_called()
|
||||
|
||||
|
||||
class TestManagedCredentials:
|
||||
"""AutoGPT-managed credentials cannot be deleted by users."""
|
||||
|
||||
def test_delete_is_managed_returns_403(self):
|
||||
cred = APIKeyCredentials(
|
||||
id="managed-cred-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-managed-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_creds_by_id = AsyncMock(return_value=cred)
|
||||
resp = client.request("DELETE", "/agent_mail/credentials/managed-cred-1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
assert "AutoGPT-managed" in resp.json()["detail"]
|
||||
|
||||
def test_list_credentials_includes_is_managed_field(self):
|
||||
managed = APIKeyCredentials(
|
||||
id="managed-1",
|
||||
provider="agent_mail",
|
||||
title="AgentMail (managed)",
|
||||
api_key=SecretStr("sk-key"),
|
||||
is_managed=True,
|
||||
)
|
||||
regular = APIKeyCredentials(
|
||||
id="regular-1",
|
||||
provider="openai",
|
||||
title="My Key",
|
||||
api_key=SecretStr("sk-key"),
|
||||
)
|
||||
with patch(
|
||||
"backend.api.features.integrations.router.creds_manager"
|
||||
) as mock_mgr:
|
||||
mock_mgr.store.get_all_creds = AsyncMock(return_value=[managed, regular])
|
||||
resp = client.get("/credentials")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
managed_cred = next(c for c in data if c["id"] == "managed-1")
|
||||
regular_cred = next(c for c in data if c["id"] == "regular-1")
|
||||
assert managed_cred["is_managed"] is True
|
||||
assert regular_cred["is_managed"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Managed credential provisioning infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_managed_cred(
|
||||
provider: str = "agent_mail", pod_id: str = "pod-abc"
|
||||
) -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="managed-auto",
|
||||
provider=provider,
|
||||
title="AgentMail (managed by AutoGPT)",
|
||||
api_key=SecretStr("sk-pod-key"),
|
||||
is_managed=True,
|
||||
metadata={"pod_id": pod_id},
|
||||
)
|
||||
|
||||
|
||||
def _make_store_mock(**kwargs) -> MagicMock:
|
||||
"""Create a store mock with a working async ``locks()`` context manager."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_locked(key):
|
||||
yield
|
||||
|
||||
locks_obj = MagicMock()
|
||||
locks_obj.locked = _noop_locked
|
||||
|
||||
store = MagicMock(**kwargs)
|
||||
store.locks = AsyncMock(return_value=locks_obj)
|
||||
return store
|
||||
|
||||
|
||||
class TestEnsureManagedCredentials:
|
||||
"""Unit tests for the ensure/cleanup helpers in managed_credentials.py."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provisions_when_missing(self):
|
||||
"""Provider.provision() is called when no managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(return_value=cred)
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
store.add_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_awaited_once_with("user-1")
|
||||
store.add_managed_credential.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_already_exists(self):
|
||||
"""Provider.provision() is NOT called when managed credential exists."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=True)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_unavailable(self):
|
||||
"""Provider.provision() is NOT called when provider is not available."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=False)
|
||||
provider.provision = AsyncMock()
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock()
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
provider.provision.assert_not_awaited()
|
||||
store.has_managed_credential.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provision_failure_does_not_propagate(self):
|
||||
"""A failed provision is logged but does not raise."""
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
_provisioned_users,
|
||||
ensure_managed_credentials,
|
||||
)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "test_provider"
|
||||
provider.is_available = AsyncMock(return_value=True)
|
||||
provider.provision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = _make_store_mock()
|
||||
store.has_managed_credential = AsyncMock(return_value=False)
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["test_provider"] = provider
|
||||
_provisioned_users.pop("user-1", None)
|
||||
try:
|
||||
await ensure_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
_provisioned_users.pop("user-1", None)
|
||||
|
||||
# No exception raised — provisioning failure is swallowed.
|
||||
|
||||
|
||||
class TestCleanupManagedCredentials:
|
||||
"""Unit tests for cleanup_managed_credentials."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_deprovision_for_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_awaited_once_with("user-1", cred)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_non_managed_creds(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
regular = _make_api_key_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "openai"
|
||||
provider.deprovision = AsyncMock()
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[regular])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["openai"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
provider.deprovision.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprovision_failure_does_not_propagate(self):
|
||||
from backend.integrations.managed_credentials import (
|
||||
_PROVIDERS,
|
||||
cleanup_managed_credentials,
|
||||
)
|
||||
|
||||
cred = _make_managed_cred()
|
||||
provider = MagicMock()
|
||||
provider.provider_name = "agent_mail"
|
||||
provider.deprovision = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
store = MagicMock()
|
||||
store.get_all_creds = AsyncMock(return_value=[cred])
|
||||
|
||||
saved = dict(_PROVIDERS)
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS["agent_mail"] = provider
|
||||
try:
|
||||
await cleanup_managed_credentials("user-1", store)
|
||||
finally:
|
||||
_PROVIDERS.clear()
|
||||
_PROVIDERS.update(saved)
|
||||
|
||||
# No exception raised — cleanup failure is swallowed.
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Shared logic for adding store agents to a user's library.
|
||||
|
||||
Both `add_store_agent_to_library` and `add_store_agent_to_library_as_admin`
|
||||
delegate to these helpers so the duplication-prone create/restore/dedup
|
||||
logic lives in exactly one place.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.data.graph as graph_db
|
||||
from backend.data.graph import GraphModel, GraphSettings
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def resolve_graph_for_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
admin: bool,
|
||||
) -> GraphModel:
|
||||
"""Look up a StoreListingVersion and resolve its graph.
|
||||
|
||||
When ``admin=True``, uses ``get_graph_as_admin`` to bypass the marketplace
|
||||
APPROVED-only check. Otherwise uses the regular ``get_graph``.
|
||||
"""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
if not slv or not slv.AgentGraph:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
ag = slv.AgentGraph
|
||||
if admin:
|
||||
graph_model = await graph_db.get_graph_as_admin(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
else:
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=ag.id, version=ag.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph_model:
|
||||
raise NotFoundError(f"Graph #{ag.id} v{ag.version} not found or accessible")
|
||||
return graph_model
|
||||
|
||||
|
||||
async def add_graph_to_library(
|
||||
store_listing_version_id: str,
|
||||
graph_model: GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Check existing / restore soft-deleted / create new LibraryAgent.
|
||||
|
||||
Uses a create-then-catch-UniqueViolationError-then-update pattern on
|
||||
the (userId, agentGraphId, agentGraphVersion) composite unique constraint.
|
||||
This is more robust than ``upsert`` because Prisma's upsert atomicity
|
||||
guarantees are not well-documented for all versions.
|
||||
"""
|
||||
settings_json = SafeJson(GraphSettings.from_graph(graph_model).model_dump())
|
||||
_include = library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
try:
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_model.id,
|
||||
"version": graph_model.version,
|
||||
}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
except prisma.errors.UniqueViolationError:
|
||||
# Already exists — update to restore if previously soft-deleted/archived
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_model.id,
|
||||
"agentGraphVersion": graph_model.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"settings": settings_json,
|
||||
},
|
||||
include=_include,
|
||||
)
|
||||
if added_agent is None:
|
||||
raise NotFoundError(
|
||||
f"LibraryAgent for graph #{graph_model.id} "
|
||||
f"v{graph_model.version} not found after UniqueViolationError"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Added graph #{graph_model.id} v{graph_model.version} "
|
||||
f"for store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
@@ -0,0 +1,80 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.errors
|
||||
import pytest
|
||||
|
||||
from ._add_to_library import add_graph_to_library
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_create_new_agent() -> None:
|
||||
"""When no matching LibraryAgent exists, create inserts a new one."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
created_agent = MagicMock(name="CreatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(return_value=created_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(created_agent)
|
||||
# Verify create was called with correct data
|
||||
create_call = mock_prisma.return_value.create.call_args
|
||||
create_data = create_call.kwargs["data"]
|
||||
assert create_data["User"] == {"connect": {"id": "user-id"}}
|
||||
assert create_data["AgentGraph"] == {
|
||||
"connect": {"graphVersionId": {"id": "graph-id", "version": 2}}
|
||||
}
|
||||
assert create_data["isCreatedByUser"] is False
|
||||
assert create_data["useGraphIsActiveVersion"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_to_library_unique_violation_updates_existing() -> None:
|
||||
"""UniqueViolationError on create falls back to update."""
|
||||
graph_model = MagicMock(id="graph-id", version=2, nodes=[])
|
||||
updated_agent = MagicMock(name="UpdatedLibraryAgent")
|
||||
converted_agent = MagicMock(name="ConvertedLibraryAgent")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.prisma.models.LibraryAgent.prisma"
|
||||
) as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library._add_to_library.library_model.LibraryAgent.from_db",
|
||||
return_value=converted_agent,
|
||||
) as mock_from_db,
|
||||
):
|
||||
mock_prisma.return_value.create = AsyncMock(
|
||||
side_effect=prisma.errors.UniqueViolationError(
|
||||
MagicMock(), message="unique constraint"
|
||||
)
|
||||
)
|
||||
mock_prisma.return_value.update = AsyncMock(return_value=updated_agent)
|
||||
|
||||
result = await add_graph_to_library("slv-id", graph_model, "user-id")
|
||||
|
||||
assert result is converted_agent
|
||||
mock_from_db.assert_called_once_with(updated_agent)
|
||||
# Verify update was called with correct where and data
|
||||
update_call = mock_prisma.return_value.update.call_args
|
||||
assert update_call.kwargs["where"] == {
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "user-id",
|
||||
"agentGraphId": "graph-id",
|
||||
"agentGraphVersion": 2,
|
||||
}
|
||||
}
|
||||
update_data = update_call.kwargs["data"]
|
||||
assert update_data["isDeleted"] is False
|
||||
assert update_data["isArchived"] is False
|
||||
@@ -336,12 +336,15 @@ async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
include_archived: bool = False,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if not include_archived:
|
||||
filter["isArchived"] = False
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
@@ -433,32 +436,53 @@ async def create_library_agent(
|
||||
async with transaction() as tx:
|
||||
library_agents = await asyncio.gather(
|
||||
*(
|
||||
prisma.models.LibraryAgent.prisma(tx).create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
prisma.models.LibraryAgent.prisma(tx).upsert(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph_entry.id,
|
||||
"agentGraphVersion": graph_entry.version,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.LibraryAgentCreateInput(
|
||||
isCreatedByUser=(user_id == graph.user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
"id": graph_entry.id,
|
||||
"version": graph_entry.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
"update": {
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"useGraphIsActiveVersion": True,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
@@ -582,7 +606,9 @@ async def update_graph_in_library(
|
||||
|
||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||
|
||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||
library_agent = await get_library_agent_by_graph_id(
|
||||
user_id, created_graph.id, include_archived=True
|
||||
)
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||
|
||||
@@ -818,92 +844,38 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Adds a marketplace agent to the user’s library.
|
||||
|
||||
See also: `add_store_agent_to_library_as_admin()` which uses
|
||||
`get_graph_as_admin` to bypass marketplace status checks for admin review.
|
||||
"""
|
||||
Adds an agent from a store listing version to the user's library if they don't already have it.
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
Args:
|
||||
store_listing_version_id: The ID of the store listing version containing the agent.
|
||||
user_id: The user’s library to which the agent is being added.
|
||||
|
||||
Returns:
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Adding agent from store listing version #{store_listing_version_id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"AgentGraph": True}
|
||||
)
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=False
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
graph = store_listing_version.AgentGraph
|
||||
|
||||
# Convert to GraphModel to check for HITL blocks
|
||||
graph_model = await graph_db.get_graph(
|
||||
graph_id=graph.id,
|
||||
version=graph.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=False,
|
||||
async def add_store_agent_to_library_as_admin(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> library_model.LibraryAgent:
|
||||
"""Admin variant that uses `get_graph_as_admin` to bypass marketplace
|
||||
APPROVED-only checks, allowing admins to add pending agents for review."""
|
||||
from ._add_to_library import add_graph_to_library, resolve_graph_for_library
|
||||
|
||||
logger.warning(
|
||||
f"ADMIN adding agent from store listing version "
|
||||
f"#{store_listing_version_id} to library for user #{user_id}"
|
||||
)
|
||||
if not graph_model:
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
# Check if user already has this agent (non-deleted)
|
||||
if existing := await get_library_agent_by_graph_id(
|
||||
user_id, graph.id, graph.version
|
||||
):
|
||||
return existing
|
||||
|
||||
# Check for soft-deleted version and restore it
|
||||
deleted_agent = await prisma.models.LibraryAgent.prisma().find_unique(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": user_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
}
|
||||
},
|
||||
graph_model = await resolve_graph_for_library(
|
||||
store_listing_version_id, user_id, admin=True
|
||||
)
|
||||
if deleted_agent and deleted_agent.isDeleted:
|
||||
return await update_library_agent(deleted_agent.id, user_id, is_deleted=False)
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(GraphSettings.from_graph(graph_model).model_dump()),
|
||||
},
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
f"for store listing version #{store_listing_version.id} "
|
||||
f"to library for user #{user_id}"
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(added_agent)
|
||||
return await add_graph_to_library(store_listing_version_id, graph_model, user_id)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma.enums
|
||||
import prisma.models
|
||||
@@ -85,10 +87,6 @@ async def test_get_library_agents(mocker):
|
||||
async def test_add_agent_to_library(mocker):
|
||||
await connect()
|
||||
|
||||
# Mock the transaction context
|
||||
mock_transaction = mocker.patch("backend.api.features.library.db.transaction")
|
||||
mock_transaction.return_value.__aenter__ = mocker.AsyncMock(return_value=None)
|
||||
mock_transaction.return_value.__aexit__ = mocker.AsyncMock(return_value=None)
|
||||
# Mock data
|
||||
mock_store_listing_data = prisma.models.StoreListingVersion(
|
||||
id="version123",
|
||||
@@ -143,15 +141,18 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
|
||||
# Mock graph_db.get_graph function that's called to check for HITL blocks
|
||||
mock_graph_db = mocker.patch("backend.api.features.library.db.graph_db")
|
||||
# Mock graph_db.get_graph function that's called in resolve_graph_for_library
|
||||
# (lives in _add_to_library.py after refactor, not db.py)
|
||||
mock_graph_db = mocker.patch(
|
||||
"backend.api.features.library._add_to_library.graph_db"
|
||||
)
|
||||
mock_graph_model = mocker.Mock()
|
||||
mock_graph_model.id = "agent1"
|
||||
mock_graph_model.version = 1
|
||||
mock_graph_model.nodes = (
|
||||
[]
|
||||
) # Empty list so _has_human_in_the_loop_blocks returns False
|
||||
@@ -170,37 +171,27 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
assert create_call_args is not None
|
||||
|
||||
# Verify the main structure
|
||||
expected_data = {
|
||||
# Verify the create data structure
|
||||
create_data = create_call_args.kwargs["data"]
|
||||
expected_create = {
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {"connect": {"graphVersionId": {"id": "agent1", "version": 1}}},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
}
|
||||
|
||||
actual_data = create_call_args[1]["data"]
|
||||
# Check that all expected fields are present
|
||||
for key, value in expected_data.items():
|
||||
assert actual_data[key] == value
|
||||
for key, value in expected_create.items():
|
||||
assert create_data[key] == value
|
||||
|
||||
# Check that settings field is present and is a SafeJson object
|
||||
assert "settings" in actual_data
|
||||
assert hasattr(actual_data["settings"], "__class__") # Should be a SafeJson object
|
||||
assert "settings" in create_data
|
||||
assert hasattr(create_data["settings"], "__class__") # Should be a SafeJson object
|
||||
|
||||
# Check include parameter
|
||||
assert create_call_args[1]["include"] == library_agent_include(
|
||||
assert create_call_args.kwargs["include"] == library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
)
|
||||
|
||||
@@ -224,3 +215,141 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_excludes_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id("test-user", "agent1", 7)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_library_agent_by_graph_id_can_include_archived(mocker):
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
|
||||
result = await db.get_library_agent_by_graph_id(
|
||||
"test-user",
|
||||
"agent1",
|
||||
7,
|
||||
include_archived=True,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_library_agent.return_value.find_first.assert_called_once()
|
||||
where = mock_library_agent.return_value.find_first.call_args.kwargs["where"]
|
||||
assert where == {
|
||||
"agentGraphId": "agent1",
|
||||
"userId": "test-user",
|
||||
"isDeleted": False,
|
||||
"agentGraphVersion": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_in_library_allows_archived_library_agent(mocker):
|
||||
graph = mocker.Mock(id="graph-id")
|
||||
existing_version = mocker.Mock(version=1, is_active=True)
|
||||
graph_model = mocker.Mock()
|
||||
created_graph = mocker.Mock(id="graph-id", version=2, is_active=False)
|
||||
current_library_agent = mocker.Mock()
|
||||
updated_library_agent = mocker.Mock()
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.get_graph_all_versions",
|
||||
new=mocker.AsyncMock(return_value=[existing_version]),
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.make_graph_model",
|
||||
return_value=graph_model,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.library.db.graph_db.create_graph",
|
||||
new=mocker.AsyncMock(return_value=created_graph),
|
||||
)
|
||||
mock_get_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.get_library_agent_by_graph_id",
|
||||
new=mocker.AsyncMock(return_value=current_library_agent),
|
||||
)
|
||||
mock_update_library_agent = mocker.patch(
|
||||
"backend.api.features.library.db.update_library_agent_version_and_settings",
|
||||
new=mocker.AsyncMock(return_value=updated_library_agent),
|
||||
)
|
||||
|
||||
result_graph, result_library_agent = await db.update_graph_in_library(
|
||||
graph,
|
||||
"test-user",
|
||||
)
|
||||
|
||||
assert result_graph is created_graph
|
||||
assert result_library_agent is updated_library_agent
|
||||
assert graph.version == 2
|
||||
graph_model.reassign_ids.assert_called_once_with(
|
||||
user_id="test-user", reassign_graph_id=False
|
||||
)
|
||||
mock_get_library_agent.assert_awaited_once_with(
|
||||
"test-user",
|
||||
"graph-id",
|
||||
include_archived=True,
|
||||
)
|
||||
mock_update_library_agent.assert_awaited_once_with("test-user", created_graph)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_library_agent_uses_upsert():
|
||||
"""create_library_agent should use upsert (not create) to handle duplicates."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-1"
|
||||
mock_graph.version = 1
|
||||
mock_graph.user_id = "user-1"
|
||||
mock_graph.nodes = []
|
||||
mock_graph.sub_graphs = []
|
||||
|
||||
mock_upserted = MagicMock(name="UpsertedLibraryAgent")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_tx():
|
||||
yield None
|
||||
|
||||
with (
|
||||
patch("backend.api.features.library.db.transaction", fake_tx),
|
||||
patch("prisma.models.LibraryAgent.prisma") as mock_prisma,
|
||||
patch(
|
||||
"backend.api.features.library.db.add_generated_agent_image",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
"backend.api.features.library.model.LibraryAgent.from_db",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=mock_upserted)
|
||||
|
||||
result = await db.create_library_agent(mock_graph, "user-1")
|
||||
|
||||
assert len(result) == 1
|
||||
upsert_call = mock_prisma.return_value.upsert.call_args
|
||||
assert upsert_call is not None
|
||||
# Verify the upsert where clause uses the composite unique key
|
||||
where = upsert_call.kwargs["where"]
|
||||
assert "userId_agentGraphId_agentGraphVersion" in where
|
||||
# Verify the upsert data has both create and update branches
|
||||
data = upsert_call.kwargs["data"]
|
||||
assert "create" in data
|
||||
assert "update" in data
|
||||
# Verify update branch restores soft-deleted/archived agents
|
||||
assert data["update"]["isDeleted"] is False
|
||||
assert data["update"]["isArchived"] is False
|
||||
|
||||
@@ -12,6 +12,7 @@ Tests cover:
|
||||
5. Complete OAuth flow end-to-end
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
@@ -58,14 +59,27 @@ async def test_user(server, test_user_id: str):
|
||||
|
||||
yield test_user_id
|
||||
|
||||
# Cleanup - delete in correct order due to foreign key constraints
|
||||
await PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthRefreshToken.prisma().delete_many(where={"userId": test_user_id})
|
||||
await PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
)
|
||||
await PrismaOAuthApplication.prisma().delete_many(where={"ownerId": test_user_id})
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
# Cleanup - delete in correct order due to foreign key constraints.
|
||||
# Wrap in try/except because the event loop or Prisma engine may already
|
||||
# be closed during session teardown on Python 3.12+.
|
||||
try:
|
||||
await asyncio.gather(
|
||||
PrismaOAuthAccessToken.prisma().delete_many(where={"userId": test_user_id}),
|
||||
PrismaOAuthRefreshToken.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
PrismaOAuthAuthorizationCode.prisma().delete_many(
|
||||
where={"userId": test_user_id}
|
||||
),
|
||||
)
|
||||
await asyncio.gather(
|
||||
PrismaOAuthApplication.prisma().delete_many(
|
||||
where={"ownerId": test_user_id}
|
||||
),
|
||||
PrismaUser.prisma().delete(where={"id": test_user_id}),
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
||||
@@ -5,16 +5,26 @@ Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin
|
||||
from typing import TYPE_CHECKING, Any, get_args, get_origin
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -154,6 +164,28 @@ class StoreAgentHandler(ContentHandler):
|
||||
}
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_enabled_blocks() -> dict[str, AnyBlockSchema]:
|
||||
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
|
||||
|
||||
Disabled blocks and blocks that fail to instantiate are silently skipped
|
||||
(with a warning log), so callers never need their own try/except loop.
|
||||
|
||||
Results are cached for the process lifetime via ``lru_cache`` because
|
||||
blocks are registered at import time and never change while running.
|
||||
"""
|
||||
enabled: dict[str, AnyBlockSchema] = {}
|
||||
for block_id, block_cls in get_blocks().items():
|
||||
try:
|
||||
instance = block_cls()
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping block {block_id}: init failed: {e}")
|
||||
continue
|
||||
if not instance.disabled:
|
||||
enabled[block_id] = instance
|
||||
return enabled
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@@ -163,16 +195,14 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
# to_thread keeps the first (heavy) call off the event loop. On
|
||||
# subsequent calls the lru_cache makes this a dict lookup, so the
|
||||
# thread-pool overhead is negligible compared to the DB queries below.
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
block_ids = list(enabled.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
@@ -187,52 +217,42 @@ class BlockHandler(ContentHandler):
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
# Convert to ContentItem — disabled filtering already done by
|
||||
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
|
||||
missing = ((bid, b) for bid, b in enabled.items() if bid not in existing_ids)
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
for block_id, block in itertools.islice(missing, batch_size):
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if block_instance.description:
|
||||
parts.append(block_instance.description)
|
||||
if block_instance.categories:
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
if not block.name:
|
||||
logger.warning(
|
||||
f"Block {block_id} has no name — using block_id as fallback"
|
||||
)
|
||||
display_name = split_camelcase(block.name) if block.name else ""
|
||||
parts = []
|
||||
if display_name:
|
||||
parts.append(display_name)
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
if block.categories:
|
||||
parts.append(" ".join(str(cat.value) for cat in block.categories))
|
||||
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
for field_name, field_info in block.input_schema.model_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
categories_list = (
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
[cat.value for cat in block.categories] if block.categories else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
credentials_info = block.input_schema.get_credentials_fields_info()
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
@@ -243,7 +263,7 @@ class BlockHandler(ContentHandler):
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
for field in block.input_schema.model_fields.values()
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -252,13 +272,13 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": block_instance.name,
|
||||
"name": display_name or block.name or block_id,
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
user_id=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -269,22 +289,13 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Filter out disabled blocks - they're not indexed
|
||||
enabled_block_ids = [
|
||||
block_id
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if not block_cls().disabled
|
||||
]
|
||||
total_blocks = len(enabled_block_ids)
|
||||
enabled = await asyncio.to_thread(_get_enabled_blocks)
|
||||
total_blocks = len(enabled)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = enabled_block_ids
|
||||
block_ids = list(enabled.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
Tests for content handlers (blocks, store agents, documentation).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -15,15 +13,103 @@ from backend.api.features.store.content_handlers import (
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
_get_enabled_blocks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_block_cache():
|
||||
"""Clear the lru_cache on _get_enabled_blocks before each test."""
|
||||
_get_enabled_blocks.cache_clear()
|
||||
yield
|
||||
_get_enabled_blocks.cache_clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build a mock block class that returns a pre-configured instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_class(
|
||||
*,
|
||||
name: str = "Block",
|
||||
description: str = "",
|
||||
disabled: bool = False,
|
||||
categories: list[MagicMock] | None = None,
|
||||
fields: dict[str, str] | None = None,
|
||||
raise_on_init: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
if raise_on_init is not None:
|
||||
cls.side_effect = raise_on_init
|
||||
return cls
|
||||
inst = MagicMock()
|
||||
inst.name = name
|
||||
inst.disabled = disabled
|
||||
inst.description = description
|
||||
inst.categories = categories or []
|
||||
field_mocks = {
|
||||
fname: MagicMock(description=fdesc) for fname, fdesc in (fields or {}).items()
|
||||
}
|
||||
inst.input_schema.model_fields = field_mocks
|
||||
inst.input_schema.get_credentials_fields_info.return_value = {}
|
||||
cls.return_value = inst
|
||||
return cls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_enabled_blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_enabled_blocks_filters_disabled():
|
||||
"""Disabled blocks are excluded."""
|
||||
blocks = {
|
||||
"enabled": _make_block_class(name="E", disabled=False),
|
||||
"disabled": _make_block_class(name="D", disabled=True),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["enabled"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_skips_broken():
|
||||
"""Blocks that raise on init are skipped without crashing."""
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["good"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_cached():
|
||||
"""_get_enabled_blocks() calls get_blocks() only once across multiple calls."""
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
) as mock_get_blocks:
|
||||
result1 = _get_enabled_blocks()
|
||||
result2 = _get_enabled_blocks()
|
||||
assert result1 is result2
|
||||
mock_get_blocks.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StoreAgentHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
@@ -54,9 +140,7 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
@@ -70,74 +154,130 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BlockHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
async def test_block_handler_get_missing_items():
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
blocks = {
|
||||
"block-uuid-1": _make_block_class(
|
||||
name="CalculatorBlock",
|
||||
description="Performs calculations",
|
||||
categories=[MagicMock(value="MATH")],
|
||||
fields={"expression": "Math expression to evaluate"},
|
||||
),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
# CamelCase should be split in searchable text and metadata name
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Calculator Block"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
async def test_block_handler_get_missing_items_splits_camelcase():
|
||||
"""CamelCase block names are split for better search indexing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert "AI Text Generator Block" in items[0].searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items_batch_size_zero():
|
||||
"""batch_size=0 returns an empty list; the DB is still queried to find missing IDs."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"b1": _make_block_class(name="B1")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
) as mock_query:
|
||||
items = await handler.get_missing_items(batch_size=0)
|
||||
assert items == []
|
||||
# DB query is still issued to learn which blocks lack embeddings;
|
||||
# the empty result comes from itertools.islice limiting to 0 items.
|
||||
mock_query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_disabled_dont_exhaust_batch():
|
||||
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# 5 disabled + 3 enabled, batch_size=2
|
||||
blocks = {
|
||||
**{
|
||||
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
|
||||
},
|
||||
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=2)
|
||||
|
||||
assert len(items) == 2
|
||||
assert all(item.content_id.startswith("en-") for item in items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats():
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||
def make_mock_block_class():
|
||||
mock_class = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.disabled = False
|
||||
mock_class.return_value = mock_instance
|
||||
return mock_class
|
||||
|
||||
mock_blocks = {
|
||||
"block-1": make_mock_block_class(),
|
||||
"block-2": make_mock_block_class(),
|
||||
"block-3": make_mock_block_class(),
|
||||
blocks = {
|
||||
"block-1": _make_block_class(name="B1"),
|
||||
"block-2": _make_block_class(name="B2"),
|
||||
"block-3": _make_block_class(name="B3"),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
@@ -150,21 +290,123 @@ async def test_block_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats_skips_broken():
|
||||
"""get_stats skips broken blocks instead of crashing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 1 # only the good block
|
||||
assert stats["with_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_none_name():
|
||||
"""When block.name is None the fallback display name logic is used."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"none-name-block": _make_block_class(
|
||||
name="placeholder", # will be overridden to None below
|
||||
description="A block with no name",
|
||||
),
|
||||
}
|
||||
# Override the name to None after construction so _make_block_class
|
||||
# doesn't interfere with the mock wiring.
|
||||
blocks["none-name-block"].return_value.name = None
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
# display_name should be "" because block.name is None
|
||||
# searchable_text should still contain the description
|
||||
assert "A block with no name" in items[0].searchable_text
|
||||
# metadata["name"] falls back to block_id when both display_name
|
||||
# and block.name are falsy, ensuring it is always a non-empty string.
|
||||
assert items[0].metadata["name"] == "none-name-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good-block": _make_block_class(name="Good Block", description="Works fine"),
|
||||
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentationHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
@@ -173,7 +415,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
@@ -184,7 +425,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
@@ -197,14 +437,12 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
@@ -224,13 +462,11 @@ async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
@@ -242,7 +478,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
@@ -254,7 +489,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
@@ -268,7 +502,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
@@ -282,21 +515,39 @@ async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
@@ -307,88 +558,3 @@ async def test_content_handlers_registry():
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
@@ -9,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.graph import (
|
||||
GraphModel,
|
||||
GraphModelWithoutNodes,
|
||||
@@ -104,7 +104,8 @@ async def get_store_agents(
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
# Fall through to direct DB search if hybrid returned nothing
|
||||
if search_used_hybrid and agents:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
@@ -130,52 +131,20 @@ async def get_store_agents(
|
||||
)
|
||||
continue
|
||||
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
if not search_used_hybrid or not agents:
|
||||
# Fallback path: direct DB query with optional tsvector search.
|
||||
# This mirrors the original pre-hybrid-search implementation.
|
||||
store_agents, total = await _fallback_store_agent_search(
|
||||
search_query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in db_agents:
|
||||
try:
|
||||
store_agents.append(store_model.StoreAgent.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from db: {e}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return store_model.StoreAgentsResponse(
|
||||
agents=store_agents,
|
||||
@@ -195,6 +164,126 @@ async def get_store_agents(
|
||||
# await log_search_term(search_query=search_term)
|
||||
|
||||
|
||||
async def _fallback_store_agent_search(
|
||||
*,
|
||||
search_query: str | None,
|
||||
featured: bool,
|
||||
creators: list[str] | None,
|
||||
category: str | None,
|
||||
sorted_by: StoreAgentsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> tuple[list[store_model.StoreAgent], int]:
|
||||
"""Direct DB search fallback when hybrid search is unavailable or empty.
|
||||
|
||||
Uses ad-hoc to_tsvector/plainto_tsquery with ts_rank_cd for text search,
|
||||
matching the quality of the original pre-hybrid-search implementation.
|
||||
Falls back to simple listing when no search query is provided.
|
||||
"""
|
||||
if not search_query:
|
||||
# No search query — use Prisma for simple filtered listing
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == StoreAgentsSortOptions.RATING:
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.RUNS:
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.NAME:
|
||||
order_by.append({"agent_name": "asc"})
|
||||
elif sorted_by == StoreAgentsSortOptions.UPDATED_AT:
|
||||
order_by.append({"updated_at": "desc"})
|
||||
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
return [store_model.StoreAgent.from_db(a) for a in db_agents], total
|
||||
|
||||
# Text search using ad-hoc tsvector on StoreAgent view fields
|
||||
params: list[Any] = [search_query]
|
||||
filters = ["sa.is_available = true"]
|
||||
param_idx = 2
|
||||
|
||||
if featured:
|
||||
filters.append("sa.featured = true")
|
||||
if creators:
|
||||
params.append(creators)
|
||||
filters.append(f"sa.creator_username = ANY(${param_idx})")
|
||||
param_idx += 1
|
||||
if category:
|
||||
params.append(category)
|
||||
filters.append(f"${param_idx} = ANY(sa.categories)")
|
||||
param_idx += 1
|
||||
|
||||
where_sql = " AND ".join(filters)
|
||||
|
||||
params.extend([page_size, (page - 1) * page_size])
|
||||
limit_param = f"${param_idx}"
|
||||
param_idx += 1
|
||||
offset_param = f"${param_idx}"
|
||||
|
||||
sql = f"""
|
||||
WITH ranked AS (
|
||||
SELECT sa.*,
|
||||
ts_rank_cd(
|
||||
to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
),
|
||||
plainto_tsquery('english', $1)
|
||||
) AS rank,
|
||||
COUNT(*) OVER () AS total_count
|
||||
FROM {{schema_prefix}}"StoreAgent" sa
|
||||
WHERE {where_sql}
|
||||
AND to_tsvector('english',
|
||||
COALESCE(sa.agent_name, '') || ' ' ||
|
||||
COALESCE(sa.sub_heading, '') || ' ' ||
|
||||
COALESCE(sa.description, '')
|
||||
) @@ plainto_tsquery('english', $1)
|
||||
)
|
||||
SELECT * FROM ranked
|
||||
ORDER BY rank DESC
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
store_agents = []
|
||||
for row in results:
|
||||
try:
|
||||
store_agents.append(
|
||||
store_model.StoreAgent(
|
||||
slug=row["slug"],
|
||||
agent_name=row["agent_name"],
|
||||
agent_image=row["agent_image"][0] if row["agent_image"] else "",
|
||||
creator=row["creator_username"] or "Needs Profile",
|
||||
creator_avatar=row["creator_avatar"] or "",
|
||||
sub_heading=row["sub_heading"],
|
||||
description=row["description"],
|
||||
runs=row["runs"],
|
||||
rating=row["rating"],
|
||||
agent_graph_id=row.get("graph_id", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing StoreAgent from fallback search: {e}")
|
||||
continue
|
||||
|
||||
return store_agents, total
|
||||
|
||||
|
||||
async def log_search_term(search_query: str):
|
||||
"""Log a search term to the database"""
|
||||
|
||||
@@ -302,6 +391,11 @@ async def get_available_graph(
|
||||
async def get_store_agent_by_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details from the StoreAgent view (APPROVED agents only).
|
||||
|
||||
See also: `get_store_agent_details_as_admin()` which bypasses the
|
||||
APPROVED-only StoreAgent view for admin preview of pending submissions.
|
||||
"""
|
||||
logger.debug(f"Getting store agent details for {store_listing_version_id}")
|
||||
|
||||
try:
|
||||
@@ -322,6 +416,57 @@ async def get_store_agent_by_version_id(
|
||||
raise DatabaseError("Failed to fetch agent details") from e
|
||||
|
||||
|
||||
async def get_store_agent_details_as_admin(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get agent details for admin preview, bypassing the APPROVED-only
|
||||
StoreAgent view. Queries StoreListingVersion directly so pending
|
||||
submissions are visible."""
|
||||
slv = await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": {"include": {"CreatorProfile": True}},
|
||||
},
|
||||
)
|
||||
if not slv or not slv.StoreListing:
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found"
|
||||
)
|
||||
|
||||
listing = slv.StoreListing
|
||||
# CreatorProfile is a required FK relation — should always exist.
|
||||
# If it's None, the DB is in a bad state.
|
||||
profile = listing.CreatorProfile
|
||||
if not profile:
|
||||
raise DatabaseError(
|
||||
f"StoreListing {listing.id} has no CreatorProfile — FK violated"
|
||||
)
|
||||
|
||||
return store_model.StoreAgentDetails(
|
||||
store_listing_version_id=slv.id,
|
||||
slug=listing.slug,
|
||||
agent_name=slv.name,
|
||||
agent_video=slv.videoUrl or "",
|
||||
agent_output_demo=slv.agentOutputDemoUrl or "",
|
||||
agent_image=slv.imageUrls,
|
||||
creator=profile.username,
|
||||
creator_avatar=profile.avatarUrl or "",
|
||||
sub_heading=slv.subHeading,
|
||||
description=slv.description,
|
||||
instructions=slv.instructions,
|
||||
categories=slv.categories,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
versions=[str(slv.version)],
|
||||
graph_id=slv.agentGraphId,
|
||||
graph_versions=[str(slv.agentGraphVersion)],
|
||||
last_updated=slv.updatedAt,
|
||||
recommended_schedule_cron=slv.recommendedScheduleCron,
|
||||
active_version_id=listing.activeVersionId or slv.id,
|
||||
has_approved_version=listing.hasApprovedVersion,
|
||||
)
|
||||
|
||||
|
||||
class StoreCreatorsSortOptions(Enum):
|
||||
# NOTE: values correspond 1:1 to columns of the Creator view
|
||||
AGENT_RATING = "agent_rating"
|
||||
@@ -1139,16 +1284,21 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
# Generate embedding for approved listing (best-effort)
|
||||
try:
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=submission.name,
|
||||
description=submission.description,
|
||||
sub_heading=submission.subHeading,
|
||||
categories=submission.categories,
|
||||
tx=tx,
|
||||
)
|
||||
except Exception as emb_err:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for listing "
|
||||
f"{store_listing_version_id}: {emb_err}"
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": submission.storeListingId},
|
||||
|
||||
@@ -15,6 +15,7 @@ from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.blocks import get_blocks
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
@@ -662,8 +663,6 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
)
|
||||
current_ids = {row["id"] for row in valid_agents}
|
||||
elif content_type == ContentType.BLOCK:
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
# Use DocumentationHandler to get section-based content IDs
|
||||
|
||||
@@ -31,12 +31,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
|
||||
"""Tokenize text for BM25."""
|
||||
if not text:
|
||||
return []
|
||||
# Lowercase and split on non-alphanumeric characters
|
||||
tokens = re.findall(r"\b\w+\b", text.lower())
|
||||
return tokens
|
||||
return re.findall(r"\b\w+\b", text.lower())
|
||||
|
||||
|
||||
def bm25_rerank(
|
||||
|
||||
@@ -14,9 +14,27 @@ from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
tokenize,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tokenize (BM25)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", ["aitextgeneratorblock"]),
|
||||
("hello world", ["hello", "world"]),
|
||||
("", []),
|
||||
("HTTPRequest", ["httprequest"]),
|
||||
],
|
||||
)
|
||||
def test_tokenize(input_text: str, expected: list[str]):
|
||||
assert tokenize(input_text) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
|
||||
import autogpt_libs.auth
|
||||
@@ -259,21 +258,18 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
) -> fastapi.responses.Response:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return fastapi.responses.Response(
|
||||
content=backend.util.json.dumps(graph_data),
|
||||
media_type="application/json",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file_name}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Backward-compatibility shim — ``split_camelcase`` now lives in backend.util.text."""
|
||||
|
||||
from backend.util.text import split_camelcase # noqa: F401
|
||||
|
||||
__all__ = ["split_camelcase"]
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Tests for split_camelcase (now in backend.util.text)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.text import split_camelcase
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# split_camelcase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", "AI Text Generator Block"),
|
||||
("HTTPRequestBlock", "HTTP Request Block"),
|
||||
("simpleWord", "simple Word"),
|
||||
("already spaced", "already spaced"),
|
||||
("XMLParser", "XML Parser"),
|
||||
("getHTTPResponse", "get HTTP Response"),
|
||||
("Block", "Block"),
|
||||
("", ""),
|
||||
("OAuth2Block", "OAuth2 Block"),
|
||||
("IOError", "IO Error"),
|
||||
("getHTTPSResponse", "get HTTPS Response"),
|
||||
# Known limitation: single-letter uppercase prefixes are NOT split.
|
||||
# "ABlock" stays "ABlock" because the algorithm requires the left
|
||||
# part of an uppercase run to retain at least 2 uppercase chars.
|
||||
("ABlock", "ABlock"),
|
||||
# Digit-to-uppercase transitions
|
||||
("Base64Encoder", "Base64 Encoder"),
|
||||
("UTF8Decoder", "UTF8 Decoder"),
|
||||
# Pure digits — no camelCase boundaries to split
|
||||
("123", "123"),
|
||||
# Known limitation: single-letter uppercase segments after digits
|
||||
# are not split from the following word. "3D" is only 1 uppercase
|
||||
# char so the uppercase-run rule cannot fire, producing "3 DRenderer"
|
||||
# rather than the ideal "3D Renderer".
|
||||
("3DRenderer", "3 DRenderer"),
|
||||
# Exception list — compound terms that should stay together
|
||||
("YouTubeBlock", "YouTube Block"),
|
||||
("OpenAIBlock", "OpenAI Block"),
|
||||
("AutoGPTAgent", "AutoGPT Agent"),
|
||||
("GitHubIntegration", "GitHub Integration"),
|
||||
("LinkedInBlock", "LinkedIn Block"),
|
||||
],
|
||||
)
|
||||
def test_split_camelcase(input_text: str, expected: str):
|
||||
assert split_camelcase(input_text) == expected
|
||||
@@ -592,6 +592,11 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
"""Configure auto top-up settings and perform an immediate top-up if needed.
|
||||
|
||||
Raises HTTPException(422) if the request parameters are invalid or if
|
||||
the credit top-up fails.
|
||||
"""
|
||||
if request.threshold < 0:
|
||||
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
|
||||
if request.amount < 500 and request.amount != 0:
|
||||
@@ -606,10 +611,20 @@ async def configure_user_auto_top_up(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
try:
|
||||
if current_balance < request.threshold:
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
except ValueError as e:
|
||||
known_messages = (
|
||||
"must not be negative",
|
||||
"already exists for user",
|
||||
"No payment method found",
|
||||
)
|
||||
if any(msg in str(e) for msg in known_messages):
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
raise
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
@@ -965,14 +980,16 @@ async def execute_graph(
|
||||
source: Annotated[GraphExecutionSource | None, Body(embed=True)] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
dry_run: Annotated[bool, Body(embed=True)] = False,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
if not dry_run:
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
@@ -982,6 +999,7 @@ async def execute_graph(
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
|
||||
@@ -188,6 +188,7 @@ async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
overwrite: bool = Query(default=False),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
@@ -248,7 +249,9 @@ async def upload_file(
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
workspace_file = await manager.write_file(
|
||||
content, filename, overwrite=overwrite
|
||||
)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -117,6 +118,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Register managed credential providers (e.g. AgentMail)
|
||||
from backend.integrations.managed_providers import register_all
|
||||
|
||||
register_all()
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
@@ -210,13 +216,22 @@ instrument_fastapi(
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
if log_error:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
if status_code >= 500:
|
||||
logger.exception(
|
||||
"%s %s failed. Investigate and resolve the underlying issue: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s %s failed with %d: %s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
status_code,
|
||||
exc,
|
||||
)
|
||||
|
||||
hint = (
|
||||
"Adjust the request and retry."
|
||||
@@ -266,12 +281,10 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
@@ -311,6 +324,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.rate_limit_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
@@ -521,8 +539,11 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from .features.integrations.router import create_credentials, get_credential
|
||||
):
|
||||
from backend.api.features.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Shared configuration for all AgentMail blocks.
|
||||
"""
|
||||
|
||||
from agentmail import AsyncAgentMail
|
||||
|
||||
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
|
||||
|
||||
agent_mail = (
|
||||
ProviderBuilder("agent_mail")
|
||||
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="agent_mail",
|
||||
title="Mock AgentMail API Key",
|
||||
api_key=SecretStr("mock-agentmail-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
|
||||
"""Create an AsyncAgentMail client from credentials."""
|
||||
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
AgentMail Attachment blocks — download file attachments from messages and threads.
|
||||
|
||||
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
|
||||
To send attachments, include them in the attachments parameter when using
|
||||
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
|
||||
|
||||
To download, first get the attachment_id from a message's attachments array,
|
||||
then use these blocks to retrieve the file content as base64.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailGetMessageAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a specific email message.
|
||||
|
||||
Retrieves the raw file content and returns it as base64-encoded data.
|
||||
First get the attachment_id from a message object's attachments array,
|
||||
then use this block to download the file.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID containing the attachment"
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from the message's attachments array)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
|
||||
description="Download a file attachment from an email message. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
message_id=message_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
message_id=input_data.message_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetThreadAttachmentBlock(Block):
|
||||
"""
|
||||
Download a file attachment from a conversation thread.
|
||||
|
||||
Same as GetMessageAttachment but looks up by thread ID instead of
|
||||
message ID. Useful when you know the thread but not the specific
|
||||
message containing the attachment.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID containing the attachment")
|
||||
attachment_id: str = SchemaField(
|
||||
description="Attachment ID to download (from a message's attachments array within the thread)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
content_base64: str = SchemaField(
|
||||
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
|
||||
)
|
||||
attachment_id: str = SchemaField(
|
||||
description="The attachment ID that was downloaded"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
|
||||
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
"attachment_id": "test-attach",
|
||||
},
|
||||
test_output=[
|
||||
("content_base64", "dGVzdA=="),
|
||||
("attachment_id", "test-attach"),
|
||||
],
|
||||
test_mock={
|
||||
"get_attachment": lambda *a, **kw: b"test",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_attachment(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
thread_id: str,
|
||||
attachment_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get_attachment(
|
||||
inbox_id=inbox_id,
|
||||
thread_id=thread_id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_attachment(
|
||||
credentials=credentials,
|
||||
inbox_id=input_data.inbox_id,
|
||||
thread_id=input_data.thread_id,
|
||||
attachment_id=input_data.attachment_id,
|
||||
)
|
||||
if isinstance(data, bytes):
|
||||
encoded = base64.b64encode(data).decode()
|
||||
elif isinstance(data, str):
|
||||
encoded = base64.b64encode(data.encode("utf-8")).decode()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unexpected attachment data type: {type(data).__name__}"
|
||||
)
|
||||
|
||||
yield "content_base64", encoded
|
||||
yield "attachment_id", input_data.attachment_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
678
autogpt_platform/backend/backend/blocks/agent_mail/drafts.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
|
||||
|
||||
A Draft is an unsent message that can be reviewed, edited, and sent later.
|
||||
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
|
||||
and complex multi-step email composition workflows.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateDraftBlock(Block):
|
||||
"""
|
||||
Create a draft email in an AgentMail inbox for review or scheduled sending.
|
||||
|
||||
Drafts let agents prepare emails without sending immediately. Use send_at
|
||||
to schedule automatic sending at a future time (ISO 8601 format).
|
||||
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
|
||||
deleting the draft.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to create the draft in"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line", default="")
|
||||
text: str = SchemaField(description="Plain text body of the draft", default="")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the draft", default="", advanced=True
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
in_reply_to: str = SchemaField(
|
||||
description="Message ID this draft replies to, for threading follow-up drafts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(
|
||||
description="Unique identifier of the created draft"
|
||||
)
|
||||
send_status: str = SchemaField(
|
||||
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
|
||||
default="",
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete draft object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
|
||||
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "mock-draft-id"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "mock-draft-id",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.create(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.in_reply_to:
|
||||
params["in_reply_to"] = input_data.in_reply_to
|
||||
if input_data.send_at:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetDraftBlock(Block):
|
||||
"""
|
||||
Retrieve a specific draft from an AgentMail inbox.
|
||||
|
||||
Returns the draft contents including recipients, subject, body, and
|
||||
scheduled send status. Use this to review a draft before approving it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="Unique identifier of the draft")
|
||||
subject: str = SchemaField(description="Draft subject line", default="")
|
||||
send_status: str = SchemaField(
|
||||
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
|
||||
default="",
|
||||
)
|
||||
send_at: str = SchemaField(
|
||||
description="Scheduled send time (ISO 8601) if set", default=""
|
||||
)
|
||||
result: dict = SchemaField(description="Complete draft object with all fields")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
|
||||
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("subject", ""),
|
||||
("send_status", ""),
|
||||
("send_at", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"subject": "",
|
||||
"send_status": "",
|
||||
"send_at": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
draft = await self.get_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "subject", draft.subject or ""
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "send_at", draft.send_at or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Use labels=['scheduled'] to find all drafts queued for future sending.
|
||||
Useful for building approval dashboards or monitoring pending outreach.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list drafts from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects with subject, recipients, send_status, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
|
||||
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_drafts(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateDraftBlock(Block):
|
||||
"""
|
||||
Update an existing draft's content, recipients, or scheduled send time.
|
||||
|
||||
Use this to reschedule a draft (change send_at), modify recipients,
|
||||
or edit the subject/body before sending. To cancel a scheduled send,
|
||||
delete the draft instead.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to update")
|
||||
to: Optional[list[str]] = SchemaField(
|
||||
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
subject: Optional[str] = SchemaField(
|
||||
description="Updated subject line. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
text: Optional[str] = SchemaField(
|
||||
description="Updated plain text body. Omit to keep current value.",
|
||||
default=None,
|
||||
)
|
||||
html: Optional[str] = SchemaField(
|
||||
description="Updated HTML body. Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
send_at: Optional[str] = SchemaField(
|
||||
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
draft_id: str = SchemaField(description="The updated draft ID")
|
||||
send_status: str = SchemaField(description="Updated send status", default="")
|
||||
result: dict = SchemaField(description="Complete updated draft object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="351f6e51-695a-421a-9032-46a587b10336",
|
||||
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("draft_id", "test-draft"),
|
||||
("send_status", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_draft": lambda *a, **kw: type(
|
||||
"Draft",
|
||||
(),
|
||||
{
|
||||
"draft_id": "test-draft",
|
||||
"send_status": "",
|
||||
"model_dump": lambda self: {"draft_id": "test-draft"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.update(
|
||||
inbox_id=inbox_id, draft_id=draft_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.to is not None:
|
||||
params["to"] = input_data.to
|
||||
if input_data.subject is not None:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text is not None:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html is not None:
|
||||
params["html"] = input_data.html
|
||||
if input_data.send_at is not None:
|
||||
params["send_at"] = input_data.send_at
|
||||
|
||||
draft = await self.update_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id, **params
|
||||
)
|
||||
result = draft.model_dump()
|
||||
|
||||
yield "draft_id", draft.draft_id
|
||||
yield "send_status", draft.send_status or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailSendDraftBlock(Block):
|
||||
"""
|
||||
Send a draft immediately, converting it into a delivered message.
|
||||
|
||||
The draft is deleted after successful sending and becomes a regular
|
||||
message with a message_id. Use this for human-in-the-loop approval
|
||||
workflows: agent creates draft, human reviews, then this block sends it.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(description="Draft ID to send now")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID of the now-sent email (draft is deleted)"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID the sent message belongs to"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete sent message object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37c39e83-475d-4b3d-843a-d923d001b85a",
|
||||
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_draft": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-msg-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.send_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteDraftBlock(Block):
|
||||
"""
|
||||
Delete a draft from an AgentMail inbox. Also cancels any scheduled send.
|
||||
|
||||
If the draft was scheduled with send_at, deleting it cancels the
|
||||
scheduled delivery. This is the way to cancel a scheduled email.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the draft belongs to"
|
||||
)
|
||||
draft_id: str = SchemaField(
|
||||
description="Draft ID to delete (also cancels scheduled sends)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the draft was successfully deleted/cancelled"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
|
||||
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"draft_id": "test-draft",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_draft": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_draft(
|
||||
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_draft(
|
||||
credentials, input_data.inbox_id, input_data.draft_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across every inbox in your organization.
|
||||
|
||||
Returns drafts from all inboxes in one query. Perfect for building
|
||||
a central approval dashboard where a human supervisor can review
|
||||
and approve any draft created by any agent.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
|
||||
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_drafts(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.drafts.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_org_drafts(credentials, **params)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
414
autogpt_platform/backend/backend/blocks/agent_mail/inbox.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
|
||||
|
||||
An Inbox is a fully programmable email account for AI agents. Each inbox gets
|
||||
a unique email address and can send, receive, and manage emails via the
|
||||
AgentMail API. You can create thousands of inboxes on demand.
|
||||
"""
|
||||
|
||||
from agentmail.inboxes.types import CreateInboxRequest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreateInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox for an AI agent via AgentMail.
|
||||
|
||||
Each inbox gets a unique email address (e.g. username@agentmail.to).
|
||||
If username and domain are not provided, AgentMail auto-generates them.
|
||||
Use custom domains by specifying the domain field.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier for the created inbox (also the email address)"
|
||||
)
|
||||
email_address: str = SchemaField(
|
||||
description="Full email address of the inbox (e.g. support@agentmail.to)"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
|
||||
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_inbox(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.create(request=CreateInboxRequest(**params))
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_inbox(credentials, **params)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing AgentMail inbox by its ID or email address.
|
||||
|
||||
Returns the inbox metadata including email address, display name, and
|
||||
configuration. Use this to check if an inbox exists or get its properties.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field", default=""
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
|
||||
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("email_address", "test-inbox"),
|
||||
("display_name", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.get(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.get_inbox(credentials, input_data.inbox_id)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "display_name", inbox.display_name or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListInboxesBlock(Block):
|
||||
"""
|
||||
List all email inboxes in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all inboxes with their metadata.
|
||||
Use page_token for pagination when you have many inboxes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page of results",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
|
||||
)
|
||||
count: int = SchemaField(
|
||||
description="Total number of inboxes in your organization"
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token to pass as page_token to get the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
|
||||
description="List all email inboxes in your AgentMail organization with pagination support.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_inboxes(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_inboxes(credentials, **params)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", (c if (c := response.count) is not None else len(inboxes))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateInboxBlock(Block):
|
||||
"""
|
||||
Update the display name of an existing AgentMail inbox.
|
||||
|
||||
Changes the friendly name shown in the 'From' field when emails are sent
|
||||
from this inbox. The email address itself cannot be changed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="New display name for the inbox (e.g. 'Customer Support Bot')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(description="The updated inbox ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
|
||||
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"display_name": "Updated",
|
||||
},
|
||||
test_output=[
|
||||
("inbox_id", "test-inbox"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "test-inbox",
|
||||
"model_dump": lambda self: {"inbox_id": "test-inbox"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.update(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
inbox = await self.update_inbox(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
display_name=input_data.display_name,
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxBlock(Block):
|
||||
"""
|
||||
Permanently delete an AgentMail inbox and all its data.
|
||||
|
||||
This removes the inbox, all its messages, threads, and drafts.
|
||||
This action cannot be undone. The email address will no longer
|
||||
receive or send emails.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to permanently delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the inbox was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade970ae-8428-4a7b-9278-b52054dbf535",
|
||||
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_inbox": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.delete(inbox_id=inbox_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_inbox(credentials, input_data.inbox_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
384
autogpt_platform/backend/backend/blocks/agent_mail/lists.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
AgentMail List blocks — manage allow/block lists for email filtering.
|
||||
|
||||
Lists let you control which email addresses and domains your agents can
|
||||
send to or receive from. There are four list types based on two dimensions:
|
||||
direction (send/receive) and type (allow/block).
|
||||
|
||||
- receive + allow: Only accept emails from these addresses/domains
|
||||
- receive + block: Reject emails from these addresses/domains
|
||||
- send + allow: Only send emails to these addresses/domains
|
||||
- send + block: Prevent sending emails to these addresses/domains
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class ListDirection(str, Enum):
|
||||
SEND = "send"
|
||||
RECEIVE = "receive"
|
||||
|
||||
|
||||
class ListType(str, Enum):
|
||||
ALLOW = "allow"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class AgentMailListEntriesBlock(Block):
|
||||
"""
|
||||
List all entries in an AgentMail allow/block list.
|
||||
|
||||
Retrieves email addresses and domains that are currently allowed
|
||||
or blocked for sending or receiving. Use direction and list_type
|
||||
to select which of the four lists to query.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to return per page",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entries: list[dict] = SchemaField(
|
||||
description="List of entries, each with an email address or domain"
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="01489100-35da-45aa-8a01-9540ba0e9a21",
|
||||
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
},
|
||||
test_output=[
|
||||
("entries", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_entries": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"entries": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_entries(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.list(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_entries(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
entries = [e.model_dump() for e in response.entries]
|
||||
|
||||
yield "entries", entries
|
||||
yield "count", (c if (c := response.count) is not None else len(entries))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreateListEntryBlock(Block):
|
||||
"""
|
||||
Add an email address or domain to an AgentMail allow/block list.
|
||||
|
||||
Entries can be full email addresses (e.g. 'partner@example.com') or
|
||||
entire domains (e.g. 'example.com'). For block lists, you can optionally
|
||||
provide a reason (e.g. 'spam', 'competitor').
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing email rules, 'receive' for incoming email rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' to whitelist, 'block' to blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address (user@example.com) or domain (example.com) to add"
|
||||
)
|
||||
reason: str = SchemaField(
|
||||
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was added"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
|
||||
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.create(direction, list_type, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"entry": input_data.entry}
|
||||
if input_data.reason and input_data.list_type == ListType.BLOCK:
|
||||
params["reason"] = input_data.reason
|
||||
|
||||
result = await self.create_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
**params,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetListEntryBlock(Block):
|
||||
"""
|
||||
Check if an email address or domain exists in an AgentMail allow/block list.
|
||||
|
||||
Returns the entry details if found. Use this to verify whether a specific
|
||||
address or domain is currently allowed or blocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(description="Email address or domain to look up")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
entry: str = SchemaField(
|
||||
description="The email address or domain that was found"
|
||||
)
|
||||
result: dict = SchemaField(description="Complete entry object with metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
|
||||
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[
|
||||
("entry", "spam@example.com"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_entry": lambda *a, **kw: type(
|
||||
"Entry",
|
||||
(),
|
||||
{
|
||||
"model_dump": lambda self: {"entry": "spam@example.com"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.lists.get(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.get_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
result_dict = result.model_dump()
|
||||
|
||||
yield "entry", input_data.entry
|
||||
yield "result", result_dict
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteListEntryBlock(Block):
|
||||
"""
|
||||
Remove an email address or domain from an AgentMail allow/block list.
|
||||
|
||||
After removal, the address/domain will no longer be filtered by this list.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
direction: ListDirection = SchemaField(
|
||||
description="'send' for outgoing rules, 'receive' for incoming rules"
|
||||
)
|
||||
list_type: ListType = SchemaField(
|
||||
description="'allow' for whitelist, 'block' for blacklist"
|
||||
)
|
||||
entry: str = SchemaField(
|
||||
description="Email address or domain to remove from the list"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the entry was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
|
||||
description="Remove an email address or domain from an allow/block list to stop filtering it.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"direction": "receive",
|
||||
"list_type": "block",
|
||||
"entry": "spam@example.com",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_entry": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_entry(
|
||||
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.lists.delete(direction, list_type, entry=entry)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_entry(
|
||||
credentials,
|
||||
input_data.direction.value,
|
||||
input_data.list_type.value,
|
||||
input_data.entry,
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
695
autogpt_platform/backend/backend/blocks/agent_mail/messages.py
Normal file
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
|
||||
|
||||
A Message is an individual email within a Thread. Agents can send new messages
|
||||
(which create threads), reply to existing messages, forward them, and manage
|
||||
labels for state tracking (e.g. read/unread, campaign tags).
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailSendMessageBlock(Block):
|
||||
"""
|
||||
Send a new email from an AgentMail inbox, automatically creating a new thread.
|
||||
|
||||
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
|
||||
organizing messages (e.g. campaign tracking, state management).
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
|
||||
)
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses (e.g. ['user@example.com'])"
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses for human-in-the-loop oversight",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the sent message"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID grouping this message and any future replies"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete sent message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b67469b2-7748-4d81-a223-4ebd332cca89",
|
||||
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"to": ["user@example.com"],
|
||||
"subject": "Test",
|
||||
"text": "Hello",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-msg-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"send_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {
|
||||
"message_id": "mock-msg-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.send(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"to": input_data.to,
|
||||
"subject": input_data.subject,
|
||||
"text": input_data.text,
|
||||
}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
msg = await self.send_message(credentials, input_data.inbox_id, **params)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListMessagesBlock(Block):
|
||||
"""
|
||||
List all messages in an AgentMail inbox with optional label filtering.
|
||||
|
||||
Returns a paginated list of messages. Use labels to filter (e.g.
|
||||
labels=['unread'] to only get unprocessed messages). Useful for
|
||||
polling workflows or building inbox views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list messages from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of messages to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
messages: list[dict] = SchemaField(
|
||||
description="List of message objects with subject, sender, text, html, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of messages returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="721234df-c7a2-4927-b205-744badbd5844",
|
||||
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("messages", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_messages": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"messages": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.list(inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_messages(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
messages = [m.model_dump() for m in response.messages]
|
||||
|
||||
yield "messages", messages
|
||||
yield "count", (c if (c := response.count) is not None else len(messages))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetMessageBlock(Block):
|
||||
"""
|
||||
Retrieve a specific email message by ID from an AgentMail inbox.
|
||||
|
||||
Returns the full message including subject, body (text and HTML),
|
||||
sender, recipients, and attachments. Use extracted_text to get
|
||||
only the new reply content without quoted history.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="Unique identifier of the message")
|
||||
thread_id: str = SchemaField(description="Thread this message belongs to")
|
||||
subject: str = SchemaField(description="Email subject line")
|
||||
text: str = SchemaField(
|
||||
description="Full plain text body (may include quoted reply history)"
|
||||
)
|
||||
extracted_text: str = SchemaField(
|
||||
description="Just the new reply content with quoted history stripped. Best for AI processing.",
|
||||
default="",
|
||||
)
|
||||
html: str = SchemaField(description="HTML body of the email", default="")
|
||||
result: dict = SchemaField(
|
||||
description="Complete message object with all fields including sender, recipients, attachments, labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
|
||||
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("thread_id", "t1"),
|
||||
("subject", "Hi"),
|
||||
("text", "Hello"),
|
||||
("extracted_text", "Hello"),
|
||||
("html", ""),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"thread_id": "t1",
|
||||
"subject": "Hi",
|
||||
"text": "Hello",
|
||||
"extracted_text": "Hello",
|
||||
"html": "",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_message(
|
||||
credentials: APIKeyCredentials,
|
||||
inbox_id: str,
|
||||
message_id: str,
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.get(
|
||||
inbox_id=inbox_id, message_id=message_id
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
msg = await self.get_message(
|
||||
credentials, input_data.inbox_id, input_data.message_id
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "thread_id", msg.thread_id or ""
|
||||
yield "subject", msg.subject or ""
|
||||
yield "text", msg.text or ""
|
||||
yield "extracted_text", msg.extracted_text or ""
|
||||
yield "html", msg.html or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailReplyToMessageBlock(Block):
|
||||
"""
|
||||
Reply to an existing email message, keeping the reply in the same thread.
|
||||
|
||||
The reply is automatically added to the same conversation thread as the
|
||||
original message. Use this for multi-turn agent conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to send the reply from"
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
|
||||
)
|
||||
text: str = SchemaField(description="Plain text body of the reply")
|
||||
html: str = SchemaField(
|
||||
description="Rich HTML body of the reply",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the reply message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID the reply was added to")
|
||||
result: dict = SchemaField(
|
||||
description="Complete reply message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9fe53fa-5026-4547-9570-b54ccb487229",
|
||||
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"text": "Reply",
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-reply-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"reply_to_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-reply-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-reply-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def reply_to_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.reply(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"text": input_data.text}
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
reply = await self.reply_to_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = reply.model_dump()
|
||||
|
||||
yield "message_id", reply.message_id
|
||||
yield "thread_id", reply.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailForwardMessageBlock(Block):
|
||||
"""
|
||||
Forward an existing email message to one or more recipients.
|
||||
|
||||
Sends the original message content to different email addresses.
|
||||
Optionally prepend additional text or override the subject line.
|
||||
Max 50 combined recipients across to, cc, and bcc.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to forward from"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to forward")
|
||||
to: list[str] = SchemaField(
|
||||
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
|
||||
)
|
||||
cc: list[str] = SchemaField(
|
||||
description="CC recipient email addresses",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
bcc: list[str] = SchemaField(
|
||||
description="BCC recipient email addresses (hidden from other recipients)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Override the subject line (defaults to 'Fwd: <original subject>')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Additional plain text to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
html: str = SchemaField(
|
||||
description="Additional HTML to prepend before the forwarded content",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(
|
||||
description="Unique identifier of the forwarded message"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID of the forward")
|
||||
result: dict = SchemaField(
|
||||
description="Complete forwarded message object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
|
||||
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"to": ["user@example.com"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "mock-fwd-id"),
|
||||
("thread_id", "mock-thread-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"forward_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "mock-fwd-id",
|
||||
"thread_id": "mock-thread-id",
|
||||
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def forward_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.forward(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
|
||||
if total > 50:
|
||||
raise ValueError(
|
||||
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
|
||||
)
|
||||
|
||||
params: dict = {"to": input_data.to}
|
||||
if input_data.cc:
|
||||
params["cc"] = input_data.cc
|
||||
if input_data.bcc:
|
||||
params["bcc"] = input_data.bcc
|
||||
if input_data.subject:
|
||||
params["subject"] = input_data.subject
|
||||
if input_data.text:
|
||||
params["text"] = input_data.text
|
||||
if input_data.html:
|
||||
params["html"] = input_data.html
|
||||
|
||||
fwd = await self.forward_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = fwd.model_dump()
|
||||
|
||||
yield "message_id", fwd.message_id
|
||||
yield "thread_id", fwd.thread_id or ""
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailUpdateMessageBlock(Block):
|
||||
"""
|
||||
Add or remove labels on an email message for state management.
|
||||
|
||||
Labels are string tags used to track message state (read/unread),
|
||||
categorize messages (billing, support), or tag campaigns (q4-outreach).
|
||||
Common pattern: add 'read' and remove 'unread' after processing a message.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the message belongs to"
|
||||
)
|
||||
message_id: str = SchemaField(description="Message ID to update labels on")
|
||||
add_labels: list[str] = SchemaField(
|
||||
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
|
||||
default_factory=list,
|
||||
)
|
||||
remove_labels: list[str] = SchemaField(
|
||||
description="Labels to remove (e.g. ['unread', 'pending'])",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
message_id: str = SchemaField(description="The updated message ID")
|
||||
result: dict = SchemaField(
|
||||
description="Complete updated message object with current labels"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="694ff816-4c89-4a5e-a552-8c31be187735",
|
||||
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"message_id": "test-msg",
|
||||
"add_labels": ["read"],
|
||||
},
|
||||
test_output=[
|
||||
("message_id", "test-msg"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"update_message": lambda *a, **kw: type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"message_id": "test-msg",
|
||||
"model_dump": lambda self: {"message_id": "test-msg"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
|
||||
):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.messages.update(
|
||||
inbox_id=inbox_id, message_id=message_id, **params
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
if not input_data.add_labels and not input_data.remove_labels:
|
||||
raise ValueError(
|
||||
"Must specify at least one label operation: add_labels or remove_labels"
|
||||
)
|
||||
|
||||
params: dict = {}
|
||||
if input_data.add_labels:
|
||||
params["add_labels"] = input_data.add_labels
|
||||
if input_data.remove_labels:
|
||||
params["remove_labels"] = input_data.remove_labels
|
||||
|
||||
msg = await self.update_message(
|
||||
credentials,
|
||||
input_data.inbox_id,
|
||||
input_data.message_id,
|
||||
**params,
|
||||
)
|
||||
result = msg.model_dump()
|
||||
|
||||
yield "message_id", msg.message_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
651
autogpt_platform/backend/backend/blocks/agent_mail/pods.py
Normal file
@@ -0,0 +1,651 @@
|
||||
"""
|
||||
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
|
||||
|
||||
Pods provide multi-tenant isolation between your customers. Each pod acts as
|
||||
an isolated workspace containing its own inboxes, domains, threads, and drafts.
|
||||
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
|
||||
serve multiple customers.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailCreatePodBlock(Block):
|
||||
"""
|
||||
Create a new pod for multi-tenant customer isolation.
|
||||
|
||||
Each pod acts as an isolated workspace for one customer or tenant.
|
||||
Use client_id to map pods to your internal tenant IDs for idempotent
|
||||
creation (safe to retry without creating duplicates).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
client_id: str = SchemaField(
|
||||
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the created pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
|
||||
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pod_id", "mock-pod-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "mock-pod-id",
|
||||
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.create(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.client_id:
|
||||
params["client_id"] = input_data.client_id
|
||||
|
||||
pod = await self.create_pod(credentials, **params)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetPodBlock(Block):
|
||||
"""
|
||||
Retrieve details of an existing pod by its ID.
|
||||
|
||||
Returns the pod metadata including its client_id mapping and
|
||||
creation timestamp.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pod_id: str = SchemaField(description="Unique identifier of the pod")
|
||||
result: dict = SchemaField(description="Complete pod object with all metadata")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553361bc-bb1b-4322-9ad4-0c226200217e",
|
||||
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("pod_id", "test-pod"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_pod": lambda *a, **kw: type(
|
||||
"Pod",
|
||||
(),
|
||||
{
|
||||
"pod_id": "test-pod",
|
||||
"model_dump": lambda self: {"pod_id": "test-pod"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.pods.get(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
|
||||
result = pod.model_dump()
|
||||
|
||||
yield "pod_id", pod.pod_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodsBlock(Block):
|
||||
"""
|
||||
List all pods in your AgentMail organization.
|
||||
|
||||
Returns a paginated list of all tenant pods with their metadata.
|
||||
Use this to see all customer workspaces at a glance.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of pods to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
pods: list[dict] = SchemaField(
|
||||
description="List of pod objects with pod_id, client_id, creation time, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of pods returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d3725ee-2968-431a-a816-857ab41e1420",
|
||||
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("pods", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pods": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"pods": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pods(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pods(credentials, **params)
|
||||
pods = [p.model_dump() for p in response.pods]
|
||||
|
||||
yield "pods", pods
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeletePodBlock(Block):
|
||||
"""
|
||||
Permanently delete a pod. All inboxes and domains must be removed first.
|
||||
|
||||
You cannot delete a pod that still contains inboxes or domains.
|
||||
Delete all child resources first, then delete the pod.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(
|
||||
description="Pod ID to permanently delete (must have no inboxes or domains)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the pod was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
|
||||
description="Permanently delete a pod. All inboxes and domains must be removed first.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_pod": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
|
||||
client = _client(credentials)
|
||||
await client.pods.delete(pod_id=pod_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_pod(credentials, pod_id=input_data.pod_id)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodInboxesBlock(Block):
|
||||
"""
|
||||
List all inboxes within a specific pod (customer workspace).
|
||||
|
||||
Returns only the inboxes belonging to this pod, providing
|
||||
tenant-scoped visibility.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of inboxes to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inboxes: list[dict] = SchemaField(
|
||||
description="List of inbox objects within this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of inboxes returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
|
||||
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inboxes", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_inboxes": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"inboxes": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_inboxes(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
inboxes = [i.model_dump() for i in response.inboxes]
|
||||
|
||||
yield "inboxes", inboxes
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads across all inboxes within a pod.
|
||||
|
||||
Returns threads from every inbox in the pod. Use for building
|
||||
per-customer dashboards showing all email activity, or for
|
||||
supervisor agents monitoring a customer's conversations.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list threads from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
|
||||
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.threads.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_pod_threads(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListPodDraftsBlock(Block):
|
||||
"""
|
||||
List all drafts across all inboxes within a pod.
|
||||
|
||||
Returns pending drafts from every inbox in the pod. Use for
|
||||
per-customer approval dashboards or monitoring scheduled sends.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to list drafts from")
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of drafts to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
drafts: list[dict] = SchemaField(
|
||||
description="List of draft objects from all inboxes in this pod"
|
||||
)
|
||||
count: int = SchemaField(description="Number of drafts returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
|
||||
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("drafts", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_pod_drafts": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"drafts": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.drafts.list(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
|
||||
response = await self.list_pod_drafts(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
drafts = [d.model_dump() for d in response.drafts]
|
||||
|
||||
yield "drafts", drafts
|
||||
yield "count", response.count
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailCreatePodInboxBlock(Block):
|
||||
"""
|
||||
Create a new email inbox within a specific pod (customer workspace).
|
||||
|
||||
The inbox is automatically scoped to the pod and inherits its
|
||||
isolation guarantees. If username/domain are not provided,
|
||||
AgentMail auto-generates a unique address.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
|
||||
username: str = SchemaField(
|
||||
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
|
||||
default="",
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
|
||||
default="",
|
||||
)
|
||||
display_name: str = SchemaField(
|
||||
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
inbox_id: str = SchemaField(
|
||||
description="Unique identifier of the created inbox"
|
||||
)
|
||||
email_address: str = SchemaField(description="Full email address of the inbox")
|
||||
result: dict = SchemaField(
|
||||
description="Complete inbox object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6862373-1ac6-402e-89e6-7db1fea882af",
|
||||
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
|
||||
test_output=[
|
||||
("inbox_id", "mock-inbox-id"),
|
||||
("email_address", "mock-inbox-id"),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"create_pod_inbox": lambda *a, **kw: type(
|
||||
"Inbox",
|
||||
(),
|
||||
{
|
||||
"inbox_id": "mock-inbox-id",
|
||||
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.pods.inboxes.create(pod_id=pod_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {}
|
||||
if input_data.username:
|
||||
params["username"] = input_data.username
|
||||
if input_data.domain:
|
||||
params["domain"] = input_data.domain
|
||||
if input_data.display_name:
|
||||
params["display_name"] = input_data.display_name
|
||||
|
||||
inbox = await self.create_pod_inbox(
|
||||
credentials, pod_id=input_data.pod_id, **params
|
||||
)
|
||||
result = inbox.model_dump()
|
||||
|
||||
yield "inbox_id", inbox.inbox_id
|
||||
yield "email_address", inbox.inbox_id
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
438
autogpt_platform/backend/backend/blocks/agent_mail/threads.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
AgentMail Thread blocks — list, get, and delete conversation threads.
|
||||
|
||||
A Thread groups related messages into a single conversation. Threads are
|
||||
created automatically when a new message is sent and grow as replies are added.
|
||||
Threads can be queried per-inbox or across the entire organization.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
|
||||
|
||||
|
||||
class AgentMailListInboxThreadsBlock(Block):
|
||||
"""
|
||||
List all conversation threads within a specific AgentMail inbox.
|
||||
|
||||
Returns a paginated list of threads with optional label filtering.
|
||||
Use labels to find threads by campaign, status, or custom tags.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address to list threads from"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects with thread_id, subject, message count, labels, etc."
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="63dd9e2d-ef81-405c-b034-c031f0437334",
|
||||
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_threads(
|
||||
credentials, input_data.inbox_id, **params
|
||||
)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetInboxThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread from an AgentMail inbox.
|
||||
|
||||
Returns the thread with all its messages in chronological order.
|
||||
Use this to get the full conversation history for context when
|
||||
composing replies.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to retrieve")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42866290-1479-4153-83e7-550b703e9da2",
|
||||
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailDeleteInboxThreadBlock(Block):
|
||||
"""
|
||||
Permanently delete a conversation thread and all its messages from an inbox.
|
||||
|
||||
This removes the thread and every message within it. This action
|
||||
cannot be undone.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
inbox_id: str = SchemaField(
|
||||
description="Inbox ID or email address the thread belongs to"
|
||||
)
|
||||
thread_id: str = SchemaField(description="Thread ID to permanently delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="True if the thread was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
|
||||
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
is_sensitive_action=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"inbox_id": "test-inbox",
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[("success", True)],
|
||||
test_mock={
|
||||
"delete_thread": lambda *a, **kw: None,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_thread(
|
||||
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
|
||||
):
|
||||
client = _client(credentials)
|
||||
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
await self.delete_thread(
|
||||
credentials, input_data.inbox_id, input_data.thread_id
|
||||
)
|
||||
yield "success", True
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailListOrgThreadsBlock(Block):
|
||||
"""
|
||||
List conversation threads across ALL inboxes in your organization.
|
||||
|
||||
Unlike per-inbox listing, this returns threads from every inbox.
|
||||
Ideal for building supervisor agents that monitor all conversations,
|
||||
analytics dashboards, or cross-agent routing workflows.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of threads to return per page (1-100)",
|
||||
default=20,
|
||||
advanced=True,
|
||||
)
|
||||
page_token: str = SchemaField(
|
||||
description="Token from a previous response to fetch the next page",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
labels: list[str] = SchemaField(
|
||||
description="Only return threads matching ALL of these labels",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
threads: list[dict] = SchemaField(
|
||||
description="List of thread objects from all inboxes in the organization"
|
||||
)
|
||||
count: int = SchemaField(description="Number of threads returned")
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for the next page. Empty if no more results.",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
|
||||
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
("threads", []),
|
||||
("count", 0),
|
||||
("next_page_token", ""),
|
||||
],
|
||||
test_mock={
|
||||
"list_org_threads": lambda *a, **kw: type(
|
||||
"Resp",
|
||||
(),
|
||||
{
|
||||
"threads": [],
|
||||
"count": 0,
|
||||
"next_page_token": "",
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_org_threads(credentials: APIKeyCredentials, **params):
|
||||
client = _client(credentials)
|
||||
return await client.threads.list(**params)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
params: dict = {"limit": input_data.limit}
|
||||
if input_data.page_token:
|
||||
params["page_token"] = input_data.page_token
|
||||
if input_data.labels:
|
||||
params["labels"] = input_data.labels
|
||||
|
||||
response = await self.list_org_threads(credentials, **params)
|
||||
threads = [t.model_dump() for t in response.threads]
|
||||
|
||||
yield "threads", threads
|
||||
yield "count", (c if (c := response.count) is not None else len(threads))
|
||||
yield "next_page_token", response.next_page_token or ""
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class AgentMailGetOrgThreadBlock(Block):
|
||||
"""
|
||||
Retrieve a single conversation thread by ID from anywhere in the organization.
|
||||
|
||||
Works without needing to know which inbox the thread belongs to.
|
||||
Returns the thread with all its messages in chronological order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = agent_mail.credentials_field(
|
||||
description="AgentMail API key from https://console.agentmail.to"
|
||||
)
|
||||
thread_id: str = SchemaField(
|
||||
description="Thread ID to retrieve (works across all inboxes)"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
thread_id: str = SchemaField(description="Unique identifier of the thread")
|
||||
messages: list[dict] = SchemaField(
|
||||
description="All messages in the thread, in chronological order"
|
||||
)
|
||||
result: dict = SchemaField(
|
||||
description="Complete thread object with all metadata"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
|
||||
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"thread_id": "test-thread",
|
||||
},
|
||||
test_output=[
|
||||
("thread_id", "test-thread"),
|
||||
("messages", []),
|
||||
("result", dict),
|
||||
],
|
||||
test_mock={
|
||||
"get_org_thread": lambda *a, **kw: type(
|
||||
"Thread",
|
||||
(),
|
||||
{
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
"model_dump": lambda self: {
|
||||
"thread_id": "test-thread",
|
||||
"messages": [],
|
||||
},
|
||||
},
|
||||
)(),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
|
||||
client = _client(credentials)
|
||||
return await client.threads.get(thread_id=thread_id)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
thread = await self.get_org_thread(credentials, input_data.thread_id)
|
||||
messages = [m.model_dump() for m in thread.messages]
|
||||
result = thread.model_dump()
|
||||
result["messages"] = messages
|
||||
|
||||
yield "thread_id", thread.thread_id
|
||||
yield "messages", messages
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +20,33 @@ from backend.blocks.llm import (
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
# Minimum max_output_tokens accepted by OpenAI-compatible APIs.
|
||||
# A true/false answer fits comfortably within this budget.
|
||||
MIN_LLM_OUTPUT_TOKENS = 16
|
||||
|
||||
|
||||
def _parse_boolean_response(response_text: str) -> tuple[bool, str | None]:
|
||||
"""Parse an LLM response into a boolean result.
|
||||
|
||||
Returns a ``(result, error)`` tuple. *error* is ``None`` when the
|
||||
response is unambiguous; otherwise it contains a diagnostic message
|
||||
and *result* defaults to ``False``.
|
||||
"""
|
||||
text = response_text.strip().lower()
|
||||
if text == "true":
|
||||
return True, None
|
||||
if text == "false":
|
||||
return False, None
|
||||
|
||||
# Fuzzy match – use word boundaries to avoid false positives like "untrue".
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", text))
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
return True, None
|
||||
if tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
return False, None
|
||||
|
||||
return False, f"Unclear AI response: '{response_text}'"
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
@@ -162,54 +190,26 @@ class AIConditionBlock(AIBlockBase):
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=MIN_LLM_OUTPUT_TOKENS,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
result, error = _parse_boolean_response(response.response)
|
||||
if error:
|
||||
yield "error", error
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
147
autogpt_platform/backend/backend/blocks/ai_condition_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for AIConditionBlock – regression coverage for max_tokens and error propagation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.ai_condition import (
|
||||
MIN_LLM_OUTPUT_TOKENS,
|
||||
AIConditionBlock,
|
||||
_parse_boolean_response,
|
||||
)
|
||||
from backend.blocks.llm import (
|
||||
DEFAULT_LLM_MODEL,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AICredentials,
|
||||
LLMResponse,
|
||||
)
|
||||
|
||||
_TEST_AI_CREDENTIALS = cast(AICredentials, TEST_CREDENTIALS_INPUT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to collect all yields from the async generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _collect_outputs(block: AIConditionBlock, input_data, credentials):
|
||||
outputs: dict[str, object] = {}
|
||||
async for name, value in block.run(input_data, credentials=credentials):
|
||||
outputs[name] = value
|
||||
return outputs
|
||||
|
||||
|
||||
def _make_input(**overrides) -> AIConditionBlock.Input:
|
||||
defaults: dict = {
|
||||
"input_value": "hello@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "yes!",
|
||||
"no_value": "no!",
|
||||
"model": DEFAULT_LLM_MODEL,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return AIConditionBlock.Input(**defaults)
|
||||
|
||||
|
||||
def _mock_llm_response(response_text: str) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response=response_text,
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_boolean_response unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseBooleanResponse:
|
||||
def test_true_exact(self):
|
||||
assert _parse_boolean_response("true") == (True, None)
|
||||
|
||||
def test_false_exact(self):
|
||||
assert _parse_boolean_response("false") == (False, None)
|
||||
|
||||
def test_true_with_whitespace(self):
|
||||
assert _parse_boolean_response(" True ") == (True, None)
|
||||
|
||||
def test_yes_fuzzy(self):
|
||||
assert _parse_boolean_response("Yes") == (True, None)
|
||||
|
||||
def test_no_fuzzy(self):
|
||||
assert _parse_boolean_response("no") == (False, None)
|
||||
|
||||
def test_one_fuzzy(self):
|
||||
assert _parse_boolean_response("1") == (True, None)
|
||||
|
||||
def test_zero_fuzzy(self):
|
||||
assert _parse_boolean_response("0") == (False, None)
|
||||
|
||||
def test_unclear_response(self):
|
||||
result, error = _parse_boolean_response("I'm not sure")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
assert "Unclear" in error
|
||||
|
||||
def test_conflicting_tokens(self):
|
||||
result, error = _parse_boolean_response("true and false")
|
||||
assert result is False
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: max_tokens is set to MIN_LLM_OUTPUT_TOKENS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxTokensRegression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_receives_min_output_tokens(self):
|
||||
"""max_tokens must be MIN_LLM_OUTPUT_TOKENS (16) – the previous value
|
||||
of 1 was too low and caused OpenAI to reject the request."""
|
||||
block = AIConditionBlock()
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def spy_llm_call(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_llm_response("true")
|
||||
|
||||
block.llm_call = spy_llm_call # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
|
||||
assert captured_kwargs["max_tokens"] == MIN_LLM_OUTPUT_TOKENS
|
||||
assert captured_kwargs["max_tokens"] == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression: exceptions from llm_call must propagate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionPropagation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_exception_propagates(self):
|
||||
"""If llm_call raises, the exception must NOT be swallowed.
|
||||
Previously the block caught all exceptions and silently returned
|
||||
result=False."""
|
||||
block = AIConditionBlock()
|
||||
|
||||
async def boom(**kwargs):
|
||||
raise RuntimeError("LLM provider error")
|
||||
|
||||
block.llm_call = boom # type: ignore[assignment]
|
||||
|
||||
input_data = _make_input()
|
||||
with pytest.raises(RuntimeError, match="LLM provider error"):
|
||||
await _collect_outputs(block, input_data, credentials=TEST_CREDENTIALS)
|
||||
@@ -27,6 +27,7 @@ from backend.util.file import MediaFileType, store_media_file
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
NANO_BANANA_2 = "google/nano-banana-2"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -77,7 +78,7 @@ class AIImageCustomizerBlock(Block):
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
default=GeminiImageModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
@@ -103,7 +104,7 @@ class AIImageCustomizerBlock(Block):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Generate and edit custom images using Google's Nano-Banana models from Gemini. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -111,7 +112,7 @@ class AIImageCustomizerBlock(Block):
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"model": GeminiImageModel.NANO_BANANA_2,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
|
||||
@@ -115,6 +115,7 @@ class ImageGenModel(str, Enum):
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
@@ -131,7 +132,7 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
model: ImageGenModel = SchemaField(
|
||||
description="The AI model to use for image generation",
|
||||
default=ImageGenModel.SD3_5,
|
||||
default=ImageGenModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
size: ImageSize = SchemaField(
|
||||
@@ -165,7 +166,7 @@ class AIImageGeneratorBlock(Block):
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
|
||||
"model": ImageGenModel.RECRAFT,
|
||||
"model": ImageGenModel.NANO_BANANA_2,
|
||||
"size": ImageSize.SQUARE,
|
||||
"style": ImageStyle.REALISTIC,
|
||||
},
|
||||
@@ -179,7 +180,9 @@ class AIImageGeneratorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Return a data URI directly so store_media_file doesn't need to download
|
||||
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
"_run_client": lambda *args, **kwargs: (
|
||||
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -280,17 +283,24 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
elif input_data.model in (
|
||||
ImageGenModel.NANO_BANANA_PRO,
|
||||
ImageGenModel.NANO_BANANA_2,
|
||||
):
|
||||
# Use Nano Banana models (Google Gemini image variants)
|
||||
model_map = {
|
||||
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
|
||||
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
|
||||
}
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"resolution": "2K",
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
credentials, model_map[input_data.model], input_params
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
537
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
537
autogpt_platform/backend/backend/blocks/autopilot.py
Normal file
@@ -0,0 +1,537 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.copilot.permissions import (
|
||||
CopilotPermissions,
|
||||
ToolName,
|
||||
all_known_tool_names,
|
||||
validate_block_identifiers,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Block ID shared between autopilot.py and copilot prompting.py.
|
||||
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
|
||||
|
||||
|
||||
class ToolCallEntry(TypedDict):
|
||||
"""A single tool invocation record from an autopilot execution."""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
input: Any
|
||||
output: Any | None
|
||||
success: bool | None
|
||||
|
||||
|
||||
class TokenUsage(TypedDict):
|
||||
"""Aggregated token counts from the autopilot stream."""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class AutoPilotBlock(Block):
|
||||
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
|
||||
|
||||
The autopilot can manage agents, access workspace files, fetch web content,
|
||||
run blocks, and more. This block enables sub-agent patterns (autopilot calling
|
||||
autopilot) and scheduled autopilot execution via the agent executor.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
"""Input schema for the AutoPilot block."""
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for the autopilot to execute. "
|
||||
"The autopilot has access to platform tools like agent management, "
|
||||
"workspace files, web fetch, block execution, and more."
|
||||
),
|
||||
placeholder="Find my agents and list them",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
system_context: str = SchemaField(
|
||||
description=(
|
||||
"Optional additional context prepended to the prompt. "
|
||||
"Use this to constrain autopilot behavior, provide domain "
|
||||
"context, or set output format requirements."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to continue an existing autopilot conversation. "
|
||||
"Leave empty to start a new session. "
|
||||
"Use the session_id output from a previous run to continue."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
max_recursion_depth: int = SchemaField(
|
||||
description=(
|
||||
"Maximum nesting depth when the autopilot calls this block "
|
||||
"recursively (sub-agent pattern). Prevents infinite loops."
|
||||
),
|
||||
default=3,
|
||||
ge=1,
|
||||
le=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools: list[ToolName] = SchemaField(
|
||||
description=(
|
||||
"Tool names to filter. Works with tools_exclude to form an "
|
||||
"allow-list or deny-list. "
|
||||
"Leave empty to apply no tool filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tools_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'tools' list is interpreted. "
|
||||
"True (default): 'tools' is a deny-list — listed tools are blocked, "
|
||||
"all others are allowed. An empty 'tools' list means allow everything. "
|
||||
"False: 'tools' is an allow-list — only listed tools are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks: list[str] = SchemaField(
|
||||
description=(
|
||||
"Block identifiers to filter when the copilot uses run_block. "
|
||||
"Each entry can be: a block name (e.g. 'HTTP Request'), "
|
||||
"a full block UUID, or the first 8 hex characters of the UUID "
|
||||
"(e.g. 'c069dc6b'). Works with blocks_exclude. "
|
||||
"Leave empty to apply no block filter."
|
||||
),
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
blocks_exclude: bool = SchemaField(
|
||||
description=(
|
||||
"Controls how the 'blocks' list is interpreted. "
|
||||
"True (default): 'blocks' is a deny-list — listed blocks are blocked, "
|
||||
"all others are allowed. An empty 'blocks' list means allow everything. "
|
||||
"False: 'blocks' is an allow-list — only listed blocks are permitted."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dry_run: bool = SchemaField(
|
||||
description=(
|
||||
"When enabled, run_block and run_agent tool calls in this "
|
||||
"autopilot session are forced to use dry-run simulation mode. "
|
||||
"No real API calls, side effects, or credits are consumed "
|
||||
"by those tools. Useful for testing agent wiring and "
|
||||
"previewing outputs. "
|
||||
"Only applies when creating a new session (session_id is empty). "
|
||||
"When reusing an existing session_id, the session's original "
|
||||
"dry_run setting is preserved."
|
||||
),
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# timeout_seconds removed: the SDK manages its own heartbeat-based
|
||||
# timeouts internally; wrapping with asyncio.timeout corrupts the
|
||||
# SDK's internal stream (see service.py CRITICAL comment).
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
"""Output schema for the AutoPilot block."""
|
||||
|
||||
response: str = SchemaField(
|
||||
description="The final text response from the autopilot."
|
||||
)
|
||||
tool_calls: list[ToolCallEntry] = SchemaField(
|
||||
description=(
|
||||
"List of tools called during execution. Each entry has "
|
||||
"tool_call_id, tool_name, input, output, and success fields."
|
||||
),
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Current turn messages (user prompt + assistant reply) as JSON. "
|
||||
"It can be used for logging or analysis."
|
||||
),
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back to continue the conversation in a future run."
|
||||
),
|
||||
)
|
||||
token_usage: TokenUsage = SchemaField(
|
||||
description=(
|
||||
"Token usage statistics: prompt_tokens, "
|
||||
"completion_tokens, total_tokens."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=AUTOPILOT_BLOCK_ID,
|
||||
description=(
|
||||
"Execute tasks using AutoGPT AutoPilot with full access to "
|
||||
"platform tools (agent management, workspace files, web fetch, "
|
||||
"block execution, and more). Enables sub-agent patterns and "
|
||||
"scheduled autopilot execution."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.AGENT},
|
||||
input_schema=AutoPilotBlock.Input,
|
||||
output_schema=AutoPilotBlock.Output,
|
||||
test_input={
|
||||
"prompt": "List my agents",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
},
|
||||
test_output=[
|
||||
("response", "You have 2 agents: Agent A and Agent B."),
|
||||
("tool_calls", []),
|
||||
(
|
||||
"conversation_history",
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
),
|
||||
("session_id", "test-session-id"),
|
||||
(
|
||||
"token_usage",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_session": lambda *args, **kwargs: "test-session-id",
|
||||
"execute_copilot": lambda *args, **kwargs: (
|
||||
"You have 2 agents: Agent A and Agent B.",
|
||||
[],
|
||||
'[{"role": "user", "content": "List my agents"}]',
|
||||
"test-session-id",
|
||||
{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_session(self, user_id: str, *, dry_run: bool) -> str:
|
||||
"""Create a new chat session and return its ID (mockable for tests)."""
|
||||
from backend.copilot.model import create_chat_session # avoid circular import
|
||||
|
||||
session = await create_chat_session(user_id, dry_run=dry_run)
|
||||
return session.session_id
|
||||
|
||||
async def execute_copilot(
|
||||
self,
|
||||
prompt: str,
|
||||
system_context: str,
|
||||
session_id: str,
|
||||
max_recursion_depth: int,
|
||||
user_id: str,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
|
||||
"""Invoke the copilot and collect all stream results.
|
||||
|
||||
Delegates to :func:`collect_copilot_response` — the shared helper that
|
||||
consumes ``stream_chat_completion_sdk`` without wrapping it in an
|
||||
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
|
||||
|
||||
Args:
|
||||
prompt: The user task/instruction.
|
||||
system_context: Optional context prepended to the prompt.
|
||||
session_id: Chat session to use.
|
||||
max_recursion_depth: Maximum allowed recursion nesting.
|
||||
user_id: Authenticated user ID.
|
||||
permissions: Optional capability filter restricting tools/blocks.
|
||||
|
||||
Returns:
|
||||
A tuple of (response_text, tool_calls, history_json, session_id, usage).
|
||||
"""
|
||||
from backend.copilot.sdk.collect import (
|
||||
collect_copilot_response, # avoid circular import
|
||||
)
|
||||
|
||||
tokens = _check_recursion(max_recursion_depth)
|
||||
perm_token = None
|
||||
try:
|
||||
effective_permissions, perm_token = _merge_inherited_permissions(
|
||||
permissions
|
||||
)
|
||||
effective_prompt = prompt
|
||||
if system_context:
|
||||
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
|
||||
|
||||
result = await collect_copilot_response(
|
||||
session_id=session_id,
|
||||
message=effective_prompt,
|
||||
user_id=user_id,
|
||||
permissions=effective_permissions,
|
||||
)
|
||||
|
||||
# Build a lightweight conversation summary from streamed data.
|
||||
turn_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": effective_prompt},
|
||||
]
|
||||
if result.tool_calls:
|
||||
turn_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result.response_text,
|
||||
"tool_calls": result.tool_calls,
|
||||
}
|
||||
)
|
||||
else:
|
||||
turn_messages.append(
|
||||
{"role": "assistant", "content": result.response_text}
|
||||
)
|
||||
history_json = json.dumps(turn_messages, default=str)
|
||||
|
||||
tool_calls: list[ToolCallEntry] = [
|
||||
{
|
||||
"tool_call_id": tc["tool_call_id"],
|
||||
"tool_name": tc["tool_name"],
|
||||
"input": tc["input"],
|
||||
"output": tc["output"],
|
||||
"success": tc["success"],
|
||||
}
|
||||
for tc in result.tool_calls
|
||||
]
|
||||
|
||||
usage: TokenUsage = {
|
||||
"prompt_tokens": result.prompt_tokens,
|
||||
"completion_tokens": result.completion_tokens,
|
||||
"total_tokens": result.total_tokens,
|
||||
}
|
||||
|
||||
return (
|
||||
result.response_text,
|
||||
tool_calls,
|
||||
history_json,
|
||||
session_id,
|
||||
usage,
|
||||
)
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
if perm_token is not None:
|
||||
_inherited_permissions.reset(perm_token)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Validate inputs, invoke the autopilot, and yield structured outputs.
|
||||
|
||||
Yields session_id even on failure so callers can inspect/resume the session.
|
||||
"""
|
||||
if not input_data.prompt.strip():
|
||||
yield "error", "Prompt cannot be empty."
|
||||
return
|
||||
|
||||
if not execution_context.user_id:
|
||||
yield "error", "Cannot run autopilot without an authenticated user."
|
||||
return
|
||||
|
||||
if input_data.max_recursion_depth < 1:
|
||||
yield "error", "max_recursion_depth must be at least 1."
|
||||
return
|
||||
|
||||
# Validate and build permissions eagerly — fail before creating a session.
|
||||
permissions = await _build_and_validate_permissions(input_data)
|
||||
if isinstance(permissions, str):
|
||||
# Validation error returned as a string message.
|
||||
yield "error", permissions
|
||||
return
|
||||
|
||||
# Create session eagerly so the user always gets the session_id,
|
||||
# even if the downstream stream fails (avoids orphaned sessions).
|
||||
sid = input_data.session_id
|
||||
if not sid:
|
||||
sid = await self.create_session(
|
||||
execution_context.user_id, dry_run=input_data.dry_run
|
||||
)
|
||||
|
||||
# NOTE: No asyncio.timeout() here — the SDK manages its own
|
||||
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
|
||||
# would cancel the task mid-flight, corrupting the SDK's internal
|
||||
# anyio memory stream (see service.py CRITICAL comment).
|
||||
try:
|
||||
response, tool_calls, history, _, usage = await self.execute_copilot(
|
||||
prompt=input_data.prompt,
|
||||
system_context=input_data.system_context,
|
||||
session_id=sid,
|
||||
max_recursion_depth=input_data.max_recursion_depth,
|
||||
user_id=execution_context.user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
yield "tool_calls", tool_calls
|
||||
yield "conversation_history", history
|
||||
yield "session_id", sid
|
||||
yield "token_usage", usage
|
||||
except asyncio.CancelledError:
|
||||
yield "session_id", sid
|
||||
yield "error", "AutoPilot execution was cancelled."
|
||||
raise
|
||||
except Exception as exc:
|
||||
yield "session_id", sid
|
||||
yield "error", str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – placed after the block class for top-down readability.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task-scoped recursion depth counter & chain-wide limit.
|
||||
# contextvars are scoped to the current asyncio task, so concurrent
|
||||
# graph executions each get independent counters.
|
||||
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_depth", default=0
|
||||
)
|
||||
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
|
||||
"_autopilot_recursion_limit", default=None
|
||||
)
|
||||
|
||||
|
||||
def _check_recursion(
|
||||
max_depth: int,
|
||||
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
|
||||
"""Check and increment recursion depth.
|
||||
|
||||
Returns ContextVar tokens that must be passed to ``_reset_recursion``
|
||||
when the caller exits to restore the previous depth.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the current depth already meets or exceeds the limit.
|
||||
"""
|
||||
current = _autopilot_recursion_depth.get()
|
||||
inherited = _autopilot_recursion_limit.get()
|
||||
limit = max_depth if inherited is None else min(inherited, max_depth)
|
||||
if current >= limit:
|
||||
raise RuntimeError(
|
||||
f"AutoPilot recursion depth limit reached ({limit}). "
|
||||
"The autopilot has called itself too many times."
|
||||
)
|
||||
return (
|
||||
_autopilot_recursion_depth.set(current + 1),
|
||||
_autopilot_recursion_limit.set(limit),
|
||||
)
|
||||
|
||||
|
||||
def _reset_recursion(
|
||||
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
|
||||
) -> None:
|
||||
"""Restore recursion depth and limit to their previous values."""
|
||||
_autopilot_recursion_depth.reset(tokens[0])
|
||||
_autopilot_recursion_limit.reset(tokens[1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Inherited permissions from a parent AutoPilotBlock execution.
|
||||
# This acts as a ceiling: child executions can only be more restrictive.
|
||||
_inherited_permissions: contextvars.ContextVar["CopilotPermissions | None"] = (
|
||||
contextvars.ContextVar("_inherited_permissions", default=None)
|
||||
)
|
||||
|
||||
|
||||
async def _build_and_validate_permissions(
|
||||
input_data: "AutoPilotBlock.Input",
|
||||
) -> "CopilotPermissions | str":
|
||||
"""Build a :class:`CopilotPermissions` from block input and validate it.
|
||||
|
||||
Returns a :class:`CopilotPermissions` on success or a human-readable
|
||||
error string if validation fails.
|
||||
"""
|
||||
# Tool names are validated by Pydantic via the ToolName Literal type
|
||||
# at model construction time — no runtime check needed here.
|
||||
# Validate block identifiers against live block registry.
|
||||
if input_data.blocks:
|
||||
invalid_blocks = await validate_block_identifiers(input_data.blocks)
|
||||
if invalid_blocks:
|
||||
return (
|
||||
f"Unknown block identifier(s) in 'blocks': {invalid_blocks}. "
|
||||
"Use find_block to discover valid block names and IDs. "
|
||||
"You may also use the first 8 characters of a block UUID."
|
||||
)
|
||||
|
||||
return CopilotPermissions(
|
||||
tools=list(input_data.tools),
|
||||
tools_exclude=input_data.tools_exclude,
|
||||
blocks=input_data.blocks,
|
||||
blocks_exclude=input_data.blocks_exclude,
|
||||
)
|
||||
|
||||
|
||||
def _merge_inherited_permissions(
|
||||
permissions: "CopilotPermissions | None",
|
||||
) -> "tuple[CopilotPermissions | None, contextvars.Token[CopilotPermissions | None] | None]":
|
||||
"""Merge *permissions* with any inherited parent permissions.
|
||||
|
||||
The merged result is stored back into the contextvar so that any nested
|
||||
AutoPilotBlock invocation (sub-agent) inherits the merged ceiling.
|
||||
|
||||
Returns a tuple of (merged_permissions, reset_token). The caller MUST
|
||||
reset the contextvar via ``_inherited_permissions.reset(token)`` in a
|
||||
``finally`` block when ``reset_token`` is not None — this prevents
|
||||
permission leakage between sequential independent executions in the same
|
||||
asyncio task.
|
||||
"""
|
||||
parent = _inherited_permissions.get()
|
||||
|
||||
if permissions is None and parent is None:
|
||||
return None, None
|
||||
|
||||
all_tools = all_known_tool_names()
|
||||
|
||||
if permissions is None:
|
||||
permissions = CopilotPermissions() # allow-all; will be narrowed by parent
|
||||
|
||||
merged = (
|
||||
permissions.merged_with_parent(parent, all_tools)
|
||||
if parent is not None
|
||||
else permissions
|
||||
)
|
||||
|
||||
# Store merged permissions as the new inherited ceiling for nested calls.
|
||||
# Return the token so the caller can restore the previous value in finally.
|
||||
token = _inherited_permissions.set(merged)
|
||||
return merged, token
|
||||
@@ -0,0 +1,265 @@
|
||||
"""Tests for AutoPilotBlock permission fields and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AutoPilotBlock,
|
||||
_build_and_validate_permissions,
|
||||
_inherited_permissions,
|
||||
_merge_inherited_permissions,
|
||||
)
|
||||
from backend.copilot.permissions import CopilotPermissions, all_known_tool_names
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_input(**kwargs) -> AutoPilotBlock.Input:
|
||||
defaults = {
|
||||
"prompt": "Do something",
|
||||
"system_context": "",
|
||||
"session_id": "",
|
||||
"max_recursion_depth": 3,
|
||||
"tools": [],
|
||||
"tools_exclude": True,
|
||||
"blocks": [],
|
||||
"blocks_exclude": True,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return AutoPilotBlock.Input(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_and_validate_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBuildAndValidatePermissions:
|
||||
async def test_empty_inputs_returns_empty_permissions(self):
|
||||
inp = _make_input()
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.is_empty()
|
||||
|
||||
async def test_valid_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["run_block", "web_fetch"], tools_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.tools == ["run_block", "web_fetch"]
|
||||
assert result.tools_exclude is True
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are now caught at Pydantic validation time
|
||||
(Literal type), before ``_build_and_validate_permissions`` is called."""
|
||||
with pytest.raises(ValidationError, match="not_a_real_tool"):
|
||||
_make_input(tools=["not_a_real_tool"])
|
||||
|
||||
async def test_valid_block_name_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["HTTP Request"], blocks_exclude=True)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert result.blocks == ["HTTP Request"]
|
||||
|
||||
async def test_valid_partial_uuid_accepted(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["c069dc6b"], blocks_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
|
||||
async def test_invalid_block_identifier_returns_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
inp = _make_input(blocks=["totally_fake_block"])
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, str)
|
||||
assert "totally_fake_block" in result
|
||||
assert "Unknown block identifier" in result
|
||||
|
||||
async def test_sdk_builtin_tool_names_accepted(self):
|
||||
inp = _make_input(tools=["Read", "Task", "WebSearch"], tools_exclude=False)
|
||||
result = await _build_and_validate_permissions(inp)
|
||||
assert isinstance(result, CopilotPermissions)
|
||||
assert not result.tools_exclude
|
||||
|
||||
async def test_empty_blocks_skips_validation(self):
|
||||
# Should not call validate_block_identifiers at all when blocks=[].
|
||||
with patch(
|
||||
"backend.copilot.permissions.validate_block_identifiers"
|
||||
) as mock_validate:
|
||||
inp = _make_input(blocks=[])
|
||||
await _build_and_validate_permissions(inp)
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_inherited_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeInheritedPermissions:
|
||||
def test_no_permissions_no_parent_returns_none(self):
|
||||
merged, token = _merge_inherited_permissions(None)
|
||||
assert merged is None
|
||||
assert token is None
|
||||
|
||||
def test_permissions_no_parent_returned_unchanged(self):
|
||||
perms = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
merged, token = _merge_inherited_permissions(perms)
|
||||
try:
|
||||
assert merged is perms
|
||||
assert token is not None
|
||||
finally:
|
||||
if token is not None:
|
||||
_inherited_permissions.reset(token)
|
||||
|
||||
def test_child_narrows_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
# Set parent as inherited
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
child = CopilotPermissions(tools=["web_fetch"], tools_exclude=True)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
all_t = all_known_tool_names()
|
||||
effective = merged.effective_allowed_tools(all_t)
|
||||
assert "bash_exec" not in effective
|
||||
assert "web_fetch" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_none_permissions_with_parent_uses_parent(self):
|
||||
parent = CopilotPermissions(tools=["bash_exec"], tools_exclude=True)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
merged, inner_token = _merge_inherited_permissions(None)
|
||||
try:
|
||||
assert merged is not None
|
||||
# Merged should have parent's restrictions
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
def test_child_cannot_expand_parent_whitelist(self):
|
||||
parent = CopilotPermissions(tools=["run_block"], tools_exclude=False)
|
||||
outer_token = _inherited_permissions.set(parent)
|
||||
try:
|
||||
# Child tries to allow more tools
|
||||
child = CopilotPermissions(
|
||||
tools=["run_block", "bash_exec"], tools_exclude=False
|
||||
)
|
||||
merged, inner_token = _merge_inherited_permissions(child)
|
||||
try:
|
||||
assert merged is not None
|
||||
effective = merged.effective_allowed_tools(all_known_tool_names())
|
||||
assert "bash_exec" not in effective
|
||||
assert "run_block" in effective
|
||||
finally:
|
||||
if inner_token is not None:
|
||||
_inherited_permissions.reset(inner_token)
|
||||
finally:
|
||||
_inherited_permissions.reset(outer_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run — validation integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoPilotBlockRunPermissions:
|
||||
async def _collect_outputs(self, block, input_data, user_id="test-user"):
|
||||
"""Helper to collect all yields from block.run()."""
|
||||
ctx = ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="g1",
|
||||
graph_exec_id="ge1",
|
||||
node_exec_id="ne1",
|
||||
node_id="n1",
|
||||
)
|
||||
outputs = {}
|
||||
async for key, val in block.run(input_data, execution_context=ctx):
|
||||
outputs[key] = val
|
||||
return outputs
|
||||
|
||||
async def test_invalid_tool_rejected_by_pydantic(self):
|
||||
"""Invalid tool names are caught at Pydantic validation (Literal type)."""
|
||||
with pytest.raises(ValidationError, match="not_a_tool"):
|
||||
_make_input(tools=["not_a_tool"])
|
||||
|
||||
async def test_invalid_block_yields_error(self):
|
||||
mock_block_cls = MagicMock()
|
||||
mock_block_cls.return_value.name = "HTTP Request"
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value={"c069dc6b-c3ed-4c12-b6e5-d47361e64ce6": mock_block_cls},
|
||||
):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(blocks=["nonexistent_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "nonexistent_block" in outputs["error"]
|
||||
|
||||
async def test_empty_prompt_yields_error_before_permission_check(self):
|
||||
block = AutoPilotBlock()
|
||||
inp = _make_input(prompt=" ", tools=["run_block"])
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
assert "error" in outputs
|
||||
assert "Prompt cannot be empty" in outputs["error"]
|
||||
|
||||
async def test_valid_permissions_passed_to_execute(self):
|
||||
"""Permissions are forwarded to execute_copilot when valid."""
|
||||
block = AutoPilotBlock()
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_execute_copilot(self_inner, **kwargs):
|
||||
captured["permissions"] = kwargs.get("permissions")
|
||||
return (
|
||||
"ok",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"test-sid",
|
||||
{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
AutoPilotBlock, "create_session", new=AsyncMock(return_value="test-sid")
|
||||
), patch.object(AutoPilotBlock, "execute_copilot", new=fake_execute_copilot):
|
||||
inp = _make_input(tools=["run_block"], tools_exclude=False)
|
||||
outputs = await self._collect_outputs(block, inp)
|
||||
|
||||
assert "error" not in outputs
|
||||
perms = captured.get("permissions")
|
||||
assert isinstance(perms, CopilotPermissions)
|
||||
assert perms.tools == ["run_block"]
|
||||
assert perms.tools_exclude is False
|
||||
@@ -472,7 +472,7 @@ class AddToListBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
if input_data.entry is not None:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
@@ -73,7 +73,7 @@ class ReadDiscordMessagesBlock(Block):
|
||||
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
|
||||
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
|
||||
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
|
||||
description="Reads messages from a Discord channel using a bot token.",
|
||||
description="Reads new messages from a Discord channel using a bot token and triggers when a new message is posted",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
test_input={
|
||||
"continuous_read": False,
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -99,6 +100,8 @@ class SendEmailBlock(Block):
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
@@ -129,6 +132,17 @@ class SendEmailBlock(Block):
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# --- SSRF Protection ---
|
||||
smtp_port = input_data.config.smtp_port
|
||||
if smtp_port not in self.ALLOWED_SMTP_PORTS:
|
||||
yield "error", (
|
||||
f"SMTP port {smtp_port} is not allowed. "
|
||||
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
|
||||
)
|
||||
return
|
||||
|
||||
await resolve_and_check_blocked(input_data.config.smtp_server)
|
||||
|
||||
status = self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
@@ -180,7 +194,19 @@ class SendEmailBlock(Block):
|
||||
"was rejected by the server. "
|
||||
"Please verify your account is authorized to send emails."
|
||||
)
|
||||
except smtplib.SMTPConnectError:
|
||||
yield "error", (
|
||||
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
|
||||
f"on port {input_data.config.smtp_port}."
|
||||
)
|
||||
except smtplib.SMTPServerDisconnected:
|
||||
yield "error", (
|
||||
f"SMTP server '{input_data.config.smtp_server}' "
|
||||
"disconnected unexpectedly."
|
||||
)
|
||||
except smtplib.SMTPDataError as e:
|
||||
yield "error", f"Email data rejected by server: {str(e)}"
|
||||
except ValueError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -34,17 +34,29 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class FluxKontextModelName(str, Enum):
|
||||
PRO = "Flux Kontext Pro"
|
||||
MAX = "Flux Kontext Max"
|
||||
class ImageEditorModel(str, Enum):
|
||||
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
|
||||
FLUX_KONTEXT_MAX = "Flux Kontext Max"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
NANO_BANANA_2 = "Nano Banana 2"
|
||||
|
||||
@property
|
||||
def api_name(self) -> str:
|
||||
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
|
||||
_map = {
|
||||
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
|
||||
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
|
||||
"NANO_BANANA_PRO": "google/nano-banana-pro",
|
||||
"NANO_BANANA_2": "google/nano-banana-2",
|
||||
}
|
||||
return _map[self.name]
|
||||
|
||||
|
||||
# Keep old name as alias for backwards compatibility
|
||||
FluxKontextModelName = ImageEditorModel
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -69,7 +81,7 @@ class AIImageEditorBlock(Block):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Flux Kontext models",
|
||||
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text instruction describing the desired edit",
|
||||
@@ -87,14 +99,14 @@ class AIImageEditorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
model: FluxKontextModelName = SchemaField(
|
||||
model: ImageEditorModel = SchemaField(
|
||||
description="Model variant to use",
|
||||
default=FluxKontextModelName.PRO,
|
||||
default=ImageEditorModel.NANO_BANANA_2,
|
||||
title="Model",
|
||||
)
|
||||
|
||||
@@ -107,7 +119,7 @@ class AIImageEditorBlock(Block):
|
||||
super().__init__(
|
||||
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
|
||||
description=(
|
||||
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
|
||||
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
|
||||
"and optional reference image to generate a modified image."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
@@ -118,7 +130,7 @@ class AIImageEditorBlock(Block):
|
||||
"input_image": "data:image/png;base64,MQ==",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
"model": ImageEditorModel.NANO_BANANA_2,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -127,7 +139,9 @@ class AIImageEditorBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
"run_model": lambda *args, **kwargs: (
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -142,7 +156,7 @@ class AIImageEditorBlock(Block):
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
model=input_data.model,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
@@ -169,7 +183,7 @@ class AIImageEditorBlock(Block):
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
model: ImageEditorModel,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
@@ -178,12 +192,29 @@ class AIImageEditorBlock(Block):
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
model_name = model.api_name
|
||||
|
||||
is_nano_banana = model in (
|
||||
ImageEditorModel.NANO_BANANA_PRO,
|
||||
ImageEditorModel.NANO_BANANA_2,
|
||||
)
|
||||
if is_nano_banana:
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high",
|
||||
}
|
||||
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
|
||||
if input_image_b64:
|
||||
input_params["image_input"] = [input_image_b64]
|
||||
else:
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
|
||||
@@ -11,7 +11,10 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
@@ -178,7 +181,8 @@ class FileOperation(StrEnum):
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
content: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
@@ -275,11 +279,11 @@ class GithubMultiFileCommitBlock(Block):
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str) -> str:
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
json={"content": content, "encoding": encoding},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
@@ -301,10 +305,19 @@ class GithubMultiFileCommitBlock(Block):
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
@@ -358,15 +371,36 @@ class GithubMultiFileCommitBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
input_data.files,
|
||||
resolved_files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
@@ -97,7 +98,11 @@ async def test_multi_file_commit_error_path():
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from abc import ABC
|
||||
from email import encoders
|
||||
from email.mime.base import MIMEBase
|
||||
@@ -8,7 +9,7 @@ from email.mime.text import MIMEText
|
||||
from email.policy import SMTP
|
||||
from email.utils import getaddresses, parseaddr
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
@@ -42,8 +43,52 @@ NO_WRAP_POLICY = SMTP.clone(max_line_length=0)
|
||||
|
||||
|
||||
def serialize_email_recipients(recipients: list[str]) -> str:
|
||||
"""Serialize recipients list to comma-separated string."""
|
||||
return ", ".join(recipients)
|
||||
"""Serialize recipients list to comma-separated string.
|
||||
|
||||
Strips leading/trailing whitespace from each address to keep MIME
|
||||
headers clean (mirrors the strip done in ``validate_email_recipients``).
|
||||
"""
|
||||
return ", ".join(addr.strip() for addr in recipients)
|
||||
|
||||
|
||||
# RFC 5322 simplified pattern: local@domain where domain has at least one dot
|
||||
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
|
||||
|
||||
|
||||
def validate_email_recipients(recipients: list[str], field_name: str = "to") -> None:
|
||||
"""Validate that all recipients are plausible email addresses.
|
||||
|
||||
Raises ``ValueError`` with a user-friendly message listing every
|
||||
invalid entry so the caller (or LLM) can correct them in one pass.
|
||||
"""
|
||||
invalid = [addr for addr in recipients if not _EMAIL_RE.match(addr.strip())]
|
||||
if invalid:
|
||||
formatted = ", ".join(f"'{a}'" for a in invalid)
|
||||
raise ValueError(
|
||||
f"Invalid email address(es) in '{field_name}': {formatted}. "
|
||||
f"Each entry must be a valid email address (e.g. user@example.com)."
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasRecipients(Protocol):
|
||||
to: list[str]
|
||||
cc: list[str]
|
||||
bcc: list[str]
|
||||
|
||||
|
||||
def validate_all_recipients(input_data: HasRecipients) -> None:
|
||||
"""Validate to/cc/bcc recipient fields on an input namespace.
|
||||
|
||||
Calls ``validate_email_recipients`` for ``to`` (required) and
|
||||
``cc``/``bcc`` (when non-empty), raising ``ValueError`` on the
|
||||
first field that contains an invalid address.
|
||||
"""
|
||||
validate_email_recipients(input_data.to, "to")
|
||||
if input_data.cc:
|
||||
validate_email_recipients(input_data.cc, "cc")
|
||||
if input_data.bcc:
|
||||
validate_email_recipients(input_data.bcc, "bcc")
|
||||
|
||||
|
||||
def _make_mime_text(
|
||||
@@ -100,14 +145,16 @@ async def create_mime_message(
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
message = MIMEMultipart()
|
||||
message["to"] = serialize_email_recipients(input_data.to)
|
||||
message["subject"] = input_data.subject
|
||||
|
||||
if input_data.cc:
|
||||
message["cc"] = ", ".join(input_data.cc)
|
||||
message["cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
message["bcc"] = ", ".join(input_data.bcc)
|
||||
message["bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
|
||||
# Use the new helper function with content_type if available
|
||||
content_type = getattr(input_data, "content_type", None)
|
||||
@@ -1167,13 +1214,15 @@ async def _build_reply_message(
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
@@ -1685,13 +1734,16 @@ To: {original_to}
|
||||
else:
|
||||
body = f"{forward_header}\n\n{original_body}"
|
||||
|
||||
# Validate all recipient lists before building the MIME message
|
||||
validate_all_recipients(input_data)
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
msg["To"] = serialize_email_recipients(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
msg["Cc"] = serialize_email_recipients(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Bcc"] = serialize_email_recipients(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
|
||||
# Add body with proper content type
|
||||
|
||||
@@ -28,9 +28,9 @@ class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
It takes in a value, name, and description.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
It outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -47,12 +47,6 @@ class AgentInputBlock(Block):
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
@@ -65,10 +59,7 @@ class AgentInputBlock(Block):
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = copy.deepcopy(self.get_field_schema("value"))
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
return copy.deepcopy(self.get_field_schema("value"))
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field for interface definition
|
||||
@@ -86,18 +77,16 @@ class AgentInputBlock(Block):
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"value": 42,
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"description": "Example numeric input.",
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
("result", 42),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
@@ -211,7 +200,7 @@ class AgentOutputBlock(Block):
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
yield "output", await formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -245,13 +234,11 @@ class AgentShortTextInputBlock(AgentInputBlock):
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -285,13 +272,11 @@ class AgentLongTextInputBlock(AgentInputBlock):
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -325,13 +310,11 @@ class AgentNumberInputBlock(AgentInputBlock):
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
@@ -501,6 +484,12 @@ class AgentDropdownInputBlock(AgentInputBlock):
|
||||
title="Dropdown Options",
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = super().generate_schema()
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
|
||||
@@ -33,6 +33,13 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.openai_responses import (
|
||||
convert_tools_to_responses_format,
|
||||
extract_responses_content,
|
||||
extract_responses_reasoning,
|
||||
extract_responses_tool_calls,
|
||||
extract_responses_usage,
|
||||
)
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
@@ -42,6 +49,9 @@ settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
# HTTP status codes for user-caused errors that should not be reported to Sentry.
|
||||
USER_ERROR_STATUS_CODES = (401, 403, 429)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
ProviderName.ANTHROPIC,
|
||||
@@ -94,6 +104,18 @@ class LlmModelMeta(EnumMeta):
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "LlmModel | None":
|
||||
"""Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'."""
|
||||
if isinstance(value, str) and "/" in value:
|
||||
stripped = value.split("/", 1)[1]
|
||||
try:
|
||||
return cls(stripped)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# OpenAI models
|
||||
O3_MINI = "o3-mini"
|
||||
O3 = "o3-2025-04-16"
|
||||
@@ -111,7 +133,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
@@ -277,9 +298,6 @@ MODEL_METADATA = {
|
||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
|
||||
), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
|
||||
@@ -706,6 +724,9 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_reasoning")
|
||||
return None
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
@@ -721,6 +742,9 @@ def extract_openai_reasoning(response) -> str | None:
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if not response.choices:
|
||||
logger.warning("LLM response has empty choices in extract_openai_tool_calls")
|
||||
return None
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
@@ -793,6 +817,19 @@ async def llm_call(
|
||||
)
|
||||
prompt = result.messages
|
||||
|
||||
# Sanitize unpaired surrogates in message content to prevent
|
||||
# UnicodeEncodeError when httpx encodes the JSON request body.
|
||||
for msg in prompt:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content.encode("utf-8")
|
||||
except UnicodeEncodeError:
|
||||
logger.warning("Sanitized unpaired surrogates in LLM prompt content")
|
||||
msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
@@ -801,36 +838,53 @@ async def llm_call(
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
|
||||
|
||||
text_config = openai.omit
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
text_config = {"format": {"type": "json_object"}} # type: ignore
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
response = await oai_client.responses.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
input=prompt, # type: ignore[arg-type]
|
||||
tools=tools_param, # type: ignore[arg-type]
|
||||
max_output_tokens=max_tokens,
|
||||
parallel_tool_calls=get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
),
|
||||
text=text_config, # type: ignore[arg-type]
|
||||
store=False,
|
||||
)
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
raw_tool_calls = extract_responses_tool_calls(response)
|
||||
tool_calls = (
|
||||
[
|
||||
ToolContentBlock(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=ToolCall(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
)
|
||||
for tc in raw_tool_calls
|
||||
]
|
||||
if raw_tool_calls
|
||||
else None
|
||||
)
|
||||
reasoning = extract_responses_reasoning(response)
|
||||
content = extract_responses_content(response)
|
||||
prompt_tokens, completion_tokens = extract_responses_usage(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
raw_response=response,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
response=content,
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
@@ -858,65 +912,60 @@ async def llm_call(
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=credentials.api_key.get_secret_value()
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Antropic is different to openai, need to iterate through
|
||||
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else getattr(resp.content[0], "text", "")
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
@@ -929,6 +978,8 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("Groq returned empty choices in response")
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -988,12 +1039,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
raise ValueError(f"OpenRouter returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1030,12 +1077,8 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"Llama API error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
raise ValueError(f"Llama API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
@@ -1065,6 +1108,8 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if not completion.choices:
|
||||
raise ValueError("AI/ML API returned empty choices in response")
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=completion.choices[0].message,
|
||||
@@ -1101,6 +1146,9 @@ async def llm_call(
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError(f"v0 API returned empty choices: {response}")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
@@ -1276,8 +1324,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
|
||||
input_data.prompt = await fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = await fmt.format_string(
|
||||
input_data.sys_prompt, values
|
||||
)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
@@ -1427,7 +1477,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
is_user_error = (
|
||||
isinstance(e, (anthropic.APIStatusError, openai.APIStatusError))
|
||||
and e.status_code in USER_ERROR_STATUS_CODES
|
||||
)
|
||||
if is_user_error:
|
||||
logger.warning(f"Error calling LLM: {e}")
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
break
|
||||
else:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
"maximum context length" in str(e).lower()
|
||||
or "token limit" in str(e).lower()
|
||||
@@ -1957,6 +2016,19 @@ class AIConversationBlock(AIBlockBase):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
has_messages = any(
|
||||
isinstance(m, dict)
|
||||
and isinstance(m.get("content"), str)
|
||||
and bool(m["content"].strip())
|
||||
for m in (input_data.messages or [])
|
||||
)
|
||||
has_prompt = bool(input_data.prompt and input_data.prompt.strip())
|
||||
if not has_messages and not has_prompt:
|
||||
raise ValueError(
|
||||
"Cannot call LLM with no messages and no prompt. "
|
||||
"Provide at least one message or a non-empty prompt."
|
||||
)
|
||||
|
||||
response = await self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt=input_data.prompt,
|
||||
|
||||
2033
autogpt_platform/backend/backend/blocks/orchestrator.py
Normal file
2033
autogpt_platform/backend/backend/blocks/orchestrator.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
# Monkey patch Stagehands to prevent signal handling in worker threads
|
||||
import stagehand.main
|
||||
from stagehand import Stagehand
|
||||
from stagehand import AsyncStagehand
|
||||
from stagehand.types.session_act_params import Options as ActOptions
|
||||
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -28,46 +23,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings("ignore", module="litellm.llms.custom_httpx")
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
def safe_register_signal_handlers(self):
|
||||
"""Only register signal handlers in the main thread"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_register_signal_handlers(self)
|
||||
else:
|
||||
# Skip signal handling in worker threads
|
||||
pass
|
||||
|
||||
|
||||
# Replace the method
|
||||
stagehand.main.Stagehand._register_signal_handlers = safe_register_signal_handlers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal_handling():
|
||||
"""Context manager to temporarily disable signal handling"""
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# In worker threads, temporarily replace signal.signal with a no-op
|
||||
original_signal = signal.signal
|
||||
|
||||
def noop_signal(*args, **kwargs):
|
||||
pass
|
||||
|
||||
signal.signal = noop_signal
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal = original_signal
|
||||
else:
|
||||
# In main thread, don't modify anything
|
||||
yield
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,13 +103,10 @@ class StagehandObserveBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -185,32 +137,28 @@ class StagehandObserveBlock(Block):
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
|
||||
observe_results = await page.observe(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
for result in observe_results:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
observe_response = await session.observe(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
for result in observe_response.data.result:
|
||||
yield "selector", result.selector
|
||||
yield "description", result.description
|
||||
yield "method", result.method
|
||||
yield "arguments", result.arguments
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandActBlock(Block):
|
||||
@@ -242,24 +190,22 @@ class StagehandActBlock(Block):
|
||||
description="Variables to use in the action. Variables contains data you want the action to use.",
|
||||
default_factory=dict,
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
)
|
||||
timeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM ready. Extended timeout for slow-loading forms",
|
||||
default=60000,
|
||||
timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms for each action.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the action was completed successfully"
|
||||
)
|
||||
message: str = SchemaField(description="Details about the action’s execution.")
|
||||
message: str = SchemaField(description="Details about the action's execution.")
|
||||
action: str = SchemaField(description="Action performed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -282,32 +228,33 @@ class StagehandActBlock(Block):
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
for action in input_data.action:
|
||||
action_results = await page.act(
|
||||
action,
|
||||
variables=input_data.variables,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
timeoutMs=input_data.timeoutMs,
|
||||
)
|
||||
yield "success", action_results.success
|
||||
yield "message", action_results.message
|
||||
yield "action", action_results.action
|
||||
for action in input_data.action:
|
||||
act_options = ActOptions(
|
||||
variables={k: v for k, v in input_data.variables.items()},
|
||||
timeout=input_data.timeout_ms,
|
||||
)
|
||||
act_response = await session.act(
|
||||
input=action,
|
||||
options=act_options,
|
||||
)
|
||||
result = act_response.data.result
|
||||
yield "success", result.success
|
||||
yield "message", result.message
|
||||
yield "action", result.action_description
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
|
||||
class StagehandExtractBlock(Block):
|
||||
@@ -335,13 +282,10 @@ class StagehandExtractBlock(Block):
|
||||
instruction: str = SchemaField(
|
||||
description="Natural language description of elements or actions to discover.",
|
||||
)
|
||||
iframes: bool = SchemaField(
|
||||
description="Whether to search within iframes. If True, Stagehand will search for actions within iframes.",
|
||||
default=True,
|
||||
)
|
||||
domSettleTimeoutMs: int = SchemaField(
|
||||
description="Timeout in milliseconds for DOM settlement.Wait longer for dynamic content",
|
||||
default=45000,
|
||||
dom_settle_timeout_ms: int = SchemaField(
|
||||
description="Timeout in ms to wait for the DOM to settle after navigation.",
|
||||
default=30000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
@@ -367,24 +311,21 @@ class StagehandExtractBlock(Block):
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
project_id=input_data.browserbase_project_id,
|
||||
async with AsyncStagehand(
|
||||
browserbase_api_key=stagehand_credentials.api_key.get_secret_value(),
|
||||
browserbase_project_id=input_data.browserbase_project_id,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
) as client:
|
||||
session = await client.sessions.start(
|
||||
model_name=input_data.model.provider_name,
|
||||
model_api_key=model_credentials.api_key.get_secret_value(),
|
||||
dom_settle_timeout_ms=input_data.dom_settle_timeout_ms,
|
||||
)
|
||||
try:
|
||||
await session.navigate(url=input_data.url)
|
||||
|
||||
await stagehand.init()
|
||||
|
||||
page = stagehand.page
|
||||
|
||||
assert page is not None, "Stagehand page is not initialized"
|
||||
|
||||
await page.goto(input_data.url)
|
||||
extraction = await page.extract(
|
||||
input_data.instruction,
|
||||
iframes=input_data.iframes,
|
||||
domSettleTimeoutMs=input_data.domSettleTimeoutMs,
|
||||
)
|
||||
yield "extraction", str(extraction.model_dump()["extraction"])
|
||||
extract_response = await session.extract(
|
||||
instruction=input_data.instruction,
|
||||
)
|
||||
yield "extraction", str(extract_response.data.result)
|
||||
finally:
|
||||
await session.end()
|
||||
|
||||
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
223
autogpt_platform/backend/backend/blocks/test/test_autopilot.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.autopilot import (
|
||||
AUTOPILOT_BLOCK_ID,
|
||||
AutoPilotBlock,
|
||||
_autopilot_recursion_depth,
|
||||
_autopilot_recursion_limit,
|
||||
_check_recursion,
|
||||
_reset_recursion,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
|
||||
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
|
||||
"""Helper to build an ExecutionContext for tests."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id="graph-1",
|
||||
graph_exec_id="gexec-1",
|
||||
graph_version=1,
|
||||
node_id="node-1",
|
||||
node_exec_id="nexec-1",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recursion guard unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRecursion:
|
||||
"""Unit tests for _check_recursion / _reset_recursion."""
|
||||
|
||||
def test_first_call_increments_depth(self):
|
||||
tokens = _check_recursion(3)
|
||||
try:
|
||||
assert _autopilot_recursion_depth.get() == 1
|
||||
assert _autopilot_recursion_limit.get() == 3
|
||||
finally:
|
||||
_reset_recursion(tokens)
|
||||
|
||||
def test_reset_restores_previous_values(self):
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
tokens = _check_recursion(5)
|
||||
_reset_recursion(tokens)
|
||||
assert _autopilot_recursion_depth.get() == 0
|
||||
assert _autopilot_recursion_limit.get() is None
|
||||
|
||||
def test_exceeding_limit_raises(self):
|
||||
t1 = _check_recursion(2)
|
||||
try:
|
||||
t2 = _check_recursion(2)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(2)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_nested_calls_respect_inherited_limit(self):
|
||||
"""Inner call with higher max_depth still respects outer limit."""
|
||||
t1 = _check_recursion(2) # sets limit=2
|
||||
try:
|
||||
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
|
||||
try:
|
||||
# depth is now 2, limit is min(10, 2) = 2 → should raise
|
||||
with pytest.raises(RuntimeError, match="recursion depth limit"):
|
||||
_check_recursion(10)
|
||||
finally:
|
||||
_reset_recursion(t2)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
def test_limit_of_one_blocks_immediately_on_second_call(self):
|
||||
t1 = _check_recursion(1)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
_check_recursion(1)
|
||||
finally:
|
||||
_reset_recursion(t1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AutoPilotBlock.run() validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
"""Tests for input validation in AutoPilotBlock.run()."""
|
||||
|
||||
@pytest.fixture
|
||||
def block(self):
|
||||
return AutoPilotBlock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_yields_error(self, block):
|
||||
block.Input # ensure schema is accessible
|
||||
input_data = block.Input(prompt=" ", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert outputs.get("error") == "Prompt cannot be empty."
|
||||
assert "response" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_yields_error(self, block):
|
||||
input_data = block.Input(prompt="hello", max_recursion_depth=3)
|
||||
ctx = _make_context(user_id="")
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
assert "authenticated user" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_run_yields_all_outputs(self, block):
|
||||
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
|
||||
mock_result = (
|
||||
"Hello world",
|
||||
[],
|
||||
'[{"role":"user","content":"hi"}]',
|
||||
"sess-abc",
|
||||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock(return_value="sess-abc")
|
||||
|
||||
input_data = block.Input(prompt="hi", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["response"] == "Hello world"
|
||||
assert outputs["tool_calls"] == []
|
||||
assert outputs["session_id"] == "sess-abc"
|
||||
assert outputs["token_usage"]["total_tokens"] == 15
|
||||
assert "error" not in outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_yields_error(self, block):
|
||||
"""On unexpected failure, run() should yield an error output."""
|
||||
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
block.create_session = AsyncMock(return_value="sess-fail")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-fail"
|
||||
assert "boom" in outputs.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_error_yields_error_and_reraises(self, block):
|
||||
"""CancelledError should yield error, then re-raise."""
|
||||
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
block.create_session = AsyncMock(return_value="sess-cancel")
|
||||
|
||||
input_data = block.Input(prompt="do something", max_recursion_depth=3)
|
||||
ctx = _make_context()
|
||||
outputs = {}
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
async for name, value in block.run(input_data, execution_context=ctx):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["session_id"] == "sess-cancel"
|
||||
assert "cancelled" in outputs.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_existing_session_id_skips_create(self, block):
|
||||
"""When session_id is provided, create_session should not be called."""
|
||||
mock_result = (
|
||||
"ok",
|
||||
[],
|
||||
"[]",
|
||||
"existing-sid",
|
||||
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
block.execute_copilot = AsyncMock(return_value=mock_result)
|
||||
block.create_session = AsyncMock()
|
||||
|
||||
input_data = block.Input(
|
||||
prompt="test", session_id="existing-sid", max_recursion_depth=3
|
||||
)
|
||||
ctx = _make_context()
|
||||
async for _ in block.run(input_data, execution_context=ctx):
|
||||
pass
|
||||
|
||||
block.create_session.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block registration / ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockRegistration:
|
||||
def test_block_id_matches_constant(self):
|
||||
block = AutoPilotBlock()
|
||||
assert block.id == AUTOPILOT_BLOCK_ID
|
||||
|
||||
def test_max_recursion_depth_has_upper_bound(self):
|
||||
"""Schema should enforce le=10."""
|
||||
schema = AutoPilotBlock.Input.model_json_schema()
|
||||
max_rec = schema["properties"]["max_recursion_depth"]
|
||||
assert (
|
||||
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
|
||||
)
|
||||
|
||||
def test_output_schema_has_no_duplicate_error_field(self):
|
||||
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
|
||||
# The field should exist (inherited) but there should be no explicit
|
||||
# redefinition. We verify by checking the class __annotations__ directly.
|
||||
assert "error" not in AutoPilotBlock.Output.__annotations__
|
||||
@@ -4,6 +4,8 @@ import pytest
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.blocks._base import Block, BlockSchemaInput
|
||||
from backend.blocks.io import AgentDropdownInputBlock, AgentInputBlock
|
||||
from backend.data.graph import BaseGraph
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
@@ -279,3 +281,66 @@ class TestAutoCredentialsFieldsValidation:
|
||||
assert "Duplicate auto_credentials kwarg_name 'credentials'" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_agent_input_block_ignores_legacy_placeholder_values():
|
||||
"""Verify AgentInputBlock.Input.model_construct tolerates extra placeholder_values
|
||||
for backward compatibility with existing agent JSON."""
|
||||
legacy_data = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
instance = AgentInputBlock.Input.model_construct(**legacy_data)
|
||||
schema = instance.generate_schema()
|
||||
assert (
|
||||
"enum" not in schema
|
||||
), "AgentInputBlock should not produce enum from legacy placeholder_values"
|
||||
|
||||
|
||||
def test_dropdown_input_block_produces_enum():
|
||||
"""Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
|
||||
options = ["Option A", "Option B"]
|
||||
instance = AgentDropdownInputBlock.Input.model_construct(
|
||||
name="choice", value=None, placeholder_values=options
|
||||
)
|
||||
schema = instance.generate_schema()
|
||||
assert schema.get("enum") == options
|
||||
|
||||
|
||||
def test_generate_schema_integration_legacy_placeholder_values():
|
||||
"""Test the full Graph._generate_schema path with legacy placeholder_values
|
||||
on AgentInputBlock — verifies no enum leaks through the graph loading path."""
|
||||
legacy_input_default = {
|
||||
"name": "url",
|
||||
"value": "",
|
||||
"description": "Enter a URL",
|
||||
"placeholder_values": ["https://example.com"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentInputBlock.Input, legacy_input_default),
|
||||
)
|
||||
url_props = result["properties"]["url"]
|
||||
assert (
|
||||
"enum" not in url_props
|
||||
), "Graph schema should not contain enum from AgentInputBlock placeholder_values"
|
||||
|
||||
|
||||
def test_generate_schema_integration_dropdown_produces_enum():
|
||||
"""Test the full Graph._generate_schema path with AgentDropdownInputBlock
|
||||
— verifies enum IS produced for dropdown blocks."""
|
||||
dropdown_input_default = {
|
||||
"name": "color",
|
||||
"value": None,
|
||||
"placeholder_values": ["Red", "Green", "Blue"],
|
||||
}
|
||||
result = BaseGraph._generate_schema(
|
||||
(AgentDropdownInputBlock.Input, dropdown_input_default),
|
||||
)
|
||||
color_props = result["properties"]["color"]
|
||||
assert color_props.get("enum") == [
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
], "Graph schema should contain enum from AgentDropdownInputBlock"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user