Compare commits

6 Commits

Author SHA1 Message Date
Bently
f46868989a Merge branch 'dev' into docs/deployment-env-variables 2026-03-05 19:37:45 +00:00
Bently
2cce5c3f1d Merge branch 'dev' into docs/deployment-env-variables 2026-02-24 14:38:55 +00:00
Bently
5c01eb4fc8 Merge branch 'dev' into docs/deployment-env-variables 2026-02-23 16:43:20 +00:00
Bently
2d7431bde6 Merge branch 'dev' into docs/deployment-env-variables 2026-02-19 17:49:09 +00:00
Bentlybro
e934df3c0c fix: address code review feedback
- Add 'text' language identifier to code blocks (MD040)
- Add VAULT_ENC_KEY generation command (openssl rand -hex 16)
- Fix DB_HOST default to 'localhost' (not 'db')
- Add info box clarifying port numbers are internal Docker ports
- Update OAuth callback URL to not include port by default
- Clarify Docker service names are internal container DNS
2026-02-16 12:10:09 +00:00
Bentlybro
8d557d33e1 docs: add deployment environment variables guide
Closes #10961, Closes OPEN-2715

Documents all environment variables that must be configured when deploying
AutoGPT to a new server:

- Quick reference table of critical URLs that must change
- Configuration file locations and loading order
- Security keys that must be regenerated (with generation commands)
- Database, Redis, RabbitMQ configuration
- Default ports for all services
- OAuth callback URLs for all supported providers
- Full deployment checklist
- Docker vs external services guidance
2026-02-16 11:59:34 +00:00
2592 changed files with 742540 additions and 60582 deletions

@@ -1,200 +0,0 @@
---
name: pr-address
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Address
## Find the PR
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
## Fetch comments (all sources)
### 1. Inline review threads — GraphQL (primary source of actionable items)
Use GraphQL to fetch inline threads. It natively exposes `isResolved`, returns threads already grouped with all replies, and paginates via cursor — no manual thread reconstruction needed.
```bash
gh api graphql -f query='
{
repository(owner: "Significant-Gravitas", name: "AutoGPT") {
pullRequest(number: {N}) {
reviewThreads(first: 100) {
pageInfo { hasNextPage endCursor }
nodes {
id
isResolved
path
comments(last: 1) {
nodes { databaseId body author { login } createdAt }
}
}
}
}
}
}'
```
If `pageInfo.hasNextPage` is true, fetch subsequent pages by adding `after: "<endCursor>"` to `reviewThreads(first: 100, after: "...")` and repeat until `hasNextPage` is false.
**Filter to unresolved threads only** — skip any thread where `isResolved: true`. `comments(last: 1)` returns the most recent comment in the thread — act on that; it reflects the reviewer's final ask. Use the thread `id` (Relay global ID) to track threads across polls.
### 2. Top-level reviews — REST (MUST paginate)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
**CRITICAL: always `--paginate`.** Reviews default to 30 per page. PRs can have 80–170+ reviews (mostly empty resolution events). Without pagination you miss reviews past position 30, including `autogpt-reviewer`'s structured review, which is typically posted after several CI runs and sits well beyond the first page.
Two things to extract:
- **Overall state**: look for `CHANGES_REQUESTED` or `APPROVED` reviews.
- **Actionable feedback**: non-empty bodies only. Empty-body reviews are thread-resolution events — they indicate progress but have no feedback to act on.
**Where each reviewer posts:**
- `autogpt-reviewer` — posts detailed structured reviews ("Blockers", "Should Fix", "Nice to Have") as **top-level reviews**. Not present on every PR. Address ALL items.
- `sentry[bot]` — posts bug predictions as **inline threads**. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — posts summaries as **top-level reviews** AND actionable items as **inline threads**. Address actionable items.
- Human reviewers — can post in any source. Address ALL non-empty feedback.
### 3. PR conversation comments — REST
```bash
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
Mostly contains: bot summaries (`coderabbitai[bot]`), CI/conflict detection (`github-actions[bot]`), and author status updates. Scan for non-empty messages from non-bot human reviewers that aren't the PR author — those are the ones that need a response.
## For each unaddressed comment
Address comments **one at a time**: fix → commit → push → inline reply → next.
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="🤖 Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="🤖 Fixed in <commit-sha>: <description>"` |
## Format and commit
After fixing, format the changed code:
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
If API routes changed, regenerate the frontend client:
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
trap "kill $REST_PID 2>/dev/null" EXIT
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.
Then commit and **push immediately** — never batch commits without pushing.
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
## The loop
```text
address comments → format → commit → push
→ wait for CI (while addressing new comments) → fix failures → push
→ re-check comments after CI settles
→ repeat until: all comments addressed AND CI green AND no new comments arriving
```
### Polling for CI + new comments
After pushing, poll for **both** CI status and new comments in a single loop. Do not use `gh pr checks --watch` — it blocks the tool and prevents reacting to new comments while CI is running.
> **Note:** `gh pr checks --watch --fail-fast` is tempting but it blocks the entire Bash tool call, meaning the agent cannot check for or address new comments until CI fully completes. Always poll manually instead.
**Polling loop — repeat every 30 seconds:**
1. Check CI status:
```bash
gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,name,link
```
Parse the results: if every check has `bucket` of `"pass"` or `"skipping"`, CI is green. If any has `"fail"`, CI has failed. Otherwise CI is still pending.
2. Check for merge conflicts:
```bash
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json mergeable --jq '.mergeable'
```
If the result is `"CONFLICTING"`, the PR has a merge conflict — see "Resolving merge conflicts" below. If `"UNKNOWN"`, GitHub is still computing mergeability — wait and re-check next poll.
3. Check for new/changed comments (all three sources):
**Inline threads** — re-run the GraphQL query from "Fetch comments". For each unresolved thread, record `{thread_id, last_comment_databaseId}` as your baseline. On each poll, action is needed if:
- A new thread `id` appears that wasn't in the baseline (new thread), OR
- An existing thread's `last_comment_databaseId` has changed (new reply on existing thread)
**Conversation comments:**
```bash
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments --paginate
```
Compare total count and newest `id` against baseline. Filter to non-empty, non-bot, non-author-update messages.
**Top-level reviews:**
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
Watch for new non-empty reviews (`CHANGES_REQUESTED` or `COMMENTED` with body). Compare total count and newest `id` against baseline.
4. **React in this precedence order (first match wins):**
| What happened | Action |
|---|---|
| Merge conflict detected | See "Resolving merge conflicts" below. |
| Mergeability is `UNKNOWN` | GitHub is still computing mergeability. Sleep 30 seconds, then restart polling from the top. |
| New comments detected | Address them (fix → commit → push → reply). After pushing, re-fetch all comments to update your baseline, then restart this polling loop from the top (new commits invalidate CI status). |
| CI failed (bucket == "fail") | Get failed check links: `gh pr checks {N} --repo Significant-Gravitas/AutoGPT --json bucket,link --jq '.[] \| select(.bucket == "fail") \| .link'`. Extract run ID from link (format: `.../actions/runs/<run-id>/job/...`), read logs with `gh run view <run-id> --repo Significant-Gravitas/AutoGPT --log-failed`. Fix → commit → push → restart polling. |
| CI green + no new comments | **Do not exit immediately.** Bots (coderabbitai, sentry) often post reviews shortly after CI settles. Continue polling for **2 more cycles (60s)** after CI goes green. Only exit after 2 consecutive green+quiet polls. |
| CI pending + no new comments | Sleep 30 seconds, then poll again. |
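The run-ID extraction mentioned in the "CI failed" row can be sketched in plain bash. This is a minimal sketch that assumes the link shape given in the table (`.../actions/runs/<run-id>/job/...`); the function name `extract_run_id` and the sample link are hypothetical.

```bash
#!/usr/bin/env bash
# Sketch: pull the Actions run ID out of a failed check's link.
# Assumes the link contains /actions/runs/<digits>, as described in the table.
extract_run_id() {
  local link="$1"
  # BASH_REMATCH[1] holds the digits captured after /actions/runs/
  if [[ "$link" =~ /actions/runs/([0-9]+) ]]; then
    printf '%s\n' "${BASH_REMATCH[1]}"
  else
    return 1   # not an Actions link (for example, an external status check)
  fi
}

# Made-up link for illustration; real links come from `gh pr checks --json link`.
extract_run_id "https://github.com/Significant-Gravitas/AutoGPT/actions/runs/123456789/job/987654321"
```

The printed ID is what gets passed to `gh run view <run-id> --log-failed`. Returning non-zero for links without a run ID lets the caller skip checks that are not backed by GitHub Actions.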
**The loop ends when:** CI fully green + all comments addressed + **2 consecutive polls with no new comments after CI settled.**
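The precedence rules above can be sketched as two small shell functions. This is an illustration under stated assumptions, not part of `gh` or of the skill itself: the function names, argument order, and action labels are hypothetical, and the caller is assumed to have already collected the `bucket` values, the `mergeable` state, and the count of new comments.

```bash
#!/usr/bin/env bash
# ci_state BUCKET...  ->  prints green | failed | pending
# Follows the rule: all pass/skipping = green, any fail = failed, else pending.
ci_state() {
  local bucket state=green
  if [ "$#" -eq 0 ]; then
    echo pending            # assumption: no checks reported yet means "not settled"
    return
  fi
  for bucket in "$@"; do
    case "$bucket" in
      fail) echo failed; return ;;   # any failure decides the result
      pass|skipping) ;;              # counts toward green
      *) state=pending ;;            # anything else is treated as still running
    esac
  done
  echo "$state"
}

# next_action MERGEABLE NEW_COMMENTS CI_STATE QUIET_GREEN_POLLS
# QUIET_GREEN_POLLS = consecutive green+quiet polls, including the current one.
# Mirrors the precedence table: the first matching branch wins.
next_action() {
  local mergeable="$1" new_comments="$2" ci="$3" quiet="$4"
  if   [ "$mergeable" = "CONFLICTING" ]; then echo resolve-conflicts
  elif [ "$mergeable" = "UNKNOWN" ];     then echo sleep-and-repoll
  elif [ "$new_comments" -gt 0 ];        then echo address-comments
  elif [ "$ci" = "failed" ];             then echo fix-ci
  elif [ "$ci" = "green" ]; then
    if [ "$quiet" -ge 2 ]; then echo done; else echo keep-polling; fi
  else
    echo sleep-and-repoll
  fi
}

# Example: CI is green and this is the first quiet poll, so keep polling.
next_action MERGEABLE 0 "$(ci_state pass skipping pass)" 1
```

Keeping the decision in a pure function, separate from the `gh` calls, makes the first-match ordering easy to check without touching a real PR.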
### Resolving merge conflicts
1. Identify the PR's target branch and remote:
```bash
gh pr view {N} --repo Significant-Gravitas/AutoGPT --json baseRefName --jq '.baseRefName'
git remote -v # find the remote pointing to Significant-Gravitas/AutoGPT (typically 'upstream' in forks, 'origin' for direct contributors)
```
2. Pull the latest base branch with a 3-way merge:
```bash
git pull {base-remote} {base-branch} --no-rebase
```
3. Resolve conflicting files, then verify no conflict markers remain:
```bash
if grep -R -n -E '^(<<<<<<<|=======|>>>>>>>)' <conflicted-files>; then
echo "Unresolved conflict markers found — resolve before proceeding."
exit 1
fi
```
4. Stage and push:
```bash
git add <conflicted-files>
git commit -m "Resolve merge conflicts with {base-branch}"
git push
```
5. Restart the polling loop from the top — new commits reset CI status.

@@ -1,74 +0,0 @@
---
name: pr-review
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review
## Find the PR
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
## Read the diff
```bash
gh pr diff {N}
```
## Fetch existing review comments
Before posting anything, fetch existing inline comments to avoid duplicates:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate
```
## What to check
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
## Output format
Every comment **must** be prefixed with `🤖` and a criticality badge:
| Tier | Badge | Meaning |
|---|---|---|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
| Should Fix | `🟠 **Should Fix**` | Important improvement |
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
| Nit | `🔵 **Nit**` | Style / wording |
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
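One way to keep the prefix and badge consistent is to build the comment body from a tier keyword. The helper below is a hypothetical sketch: the function name `format_review_comment` and the tier keywords (`blocker`, `should-fix`, `nice-to-have`, `nit`) are assumptions, while the badge strings are copied from the table.

```bash
#!/usr/bin/env bash
# Sketch: build a review comment body with the robot prefix and a criticality badge.
# Tier keywords are illustrative; badge text comes from the table above.
format_review_comment() {
  local tier="$1" message="$2" badge
  case "$tier" in
    blocker)      badge='🔴 **Blocker**' ;;
    should-fix)   badge='🟠 **Should Fix**' ;;
    nice-to-have) badge='🟡 **Nice to Have**' ;;
    nit)          badge='🔵 **Nit**' ;;
    *) echo "unknown tier: $tier" >&2; return 1 ;;
  esac
  printf '🤖 %s: %s\n' "$badge" "$message"
}

format_review_comment blocker "Missing error handling for X"
```

The result can be passed as the `body` field when posting the comment, so every finding carries the same prefix.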
## Post inline comments
For each finding, post an inline comment on the PR (do not just write a local report):
```bash
# Get the latest commit SHA for the PR
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
# Post an inline comment on a specific file/line
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
-f body="🤖 🔴 **Blocker**: <description>" \
-f commit_id="$COMMIT_SHA" \
-f path="<file path>" \
-F line=<line number>
```

@@ -1,85 +0,0 @@
---
name: worktree
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
metadata:
author: autogpt-team
version: "3.0.0"
---
# Worktree Setup
## Create the worktree
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
```bash
ROOT=$(git rev-parse --show-toplevel)
PARENT=$(dirname "$ROOT")
# From an existing branch
git worktree add "$PARENT/<NAME>" <branch-name>
# From a new branch off dev
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
```
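Picking the next `AutoGPT<N>` name can be sketched as follows. This is a minimal sketch that assumes "next available" means one more than the highest existing number; the function name `next_worktree_name` is hypothetical, and it reads worktree paths one per line on stdin.

```bash
#!/usr/bin/env bash
# Sketch: read worktree paths on stdin and print the next AutoGPT<N> name.
# Assumption: "next" = highest existing N + 1; an unnumbered "AutoGPT" is ignored.
next_worktree_name() {
  local path name n max=0
  while IFS= read -r path; do
    name="${path##*/}"                     # last path component
    if [[ "$name" =~ ^AutoGPT([0-9]+)$ ]]; then
      n=$((10#${BASH_REMATCH[1]}))         # force base 10 so "08" is not read as octal
      if [ "$n" -gt "$max" ]; then max=$n; fi
    fi
  done
  echo "AutoGPT$((max + 1))"
}

# Real usage would feed it the paths from git:
#   git worktree list --porcelain | sed -n 's/^worktree //p' | next_worktree_name
# Illustrative input with made-up paths:
printf '%s\n' /src/AutoGPT /src/AutoGPT2 /src/AutoGPT7 | next_worktree_name
```

The printed name can then be substituted for `<NAME>` in the `git worktree add` commands above.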
## Copy environment files
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
```bash
ROOT=$(git rev-parse --show-toplevel)
TARGET="$(dirname "$ROOT")/<NAME>"
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
if [ -f "$ROOT/$envpath/.env" ]; then
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$ROOT/$envpath/.env.default" ]; then
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
fi
done
```
## Install dependencies
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
cd "$TARGET/autogpt_platform/frontend" && pnpm install
```
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
## Running the app (optional)
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd "$TARGET/autogpt_platform/backend" && poetry run app
```
## CoPilot testing
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
## Cleanup
```bash
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
```
## Alternative: Branchlet (optional)
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
```bash
branchlet create -n <name> -s <source-branch> -b <new-branch>
```

@@ -5,14 +5,12 @@ on:
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
merge_group:
@@ -27,91 +25,10 @@ defaults:
working-directory: autogpt_platform/backend
jobs:
lint:
permissions:
contents: read
timeout-minutes: 10
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-py3.12-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
- name: Install Python dependencies
run: poetry install
- name: Run Linters
run: poetry run lint --skip-pyright
env:
CI: true
PLAIN_OUTPUT: True
type-check:
permissions:
contents: read
timeout-minutes: 10
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Using Poetry version ${HEAD_POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
- name: Install Python dependencies
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate && poetry run gen-prisma-stub
- name: Run Pyright
run: poetry run pyright --pythonversion ${{ matrix.python-version }}
env:
CI: true
PLAIN_OUTPUT: True
test:
permissions:
contents: read
timeout-minutes: 15
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
@@ -179,9 +96,9 @@ jobs:
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
- name: Install Poetry (Unix)
run: |
# Extract Poetry version from backend/poetry.lock
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
@@ -239,22 +156,22 @@ jobs:
echo "Waiting for ClamAV daemon to start..."
max_attempts=60
attempt=0
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
sleep 5
attempt=$((attempt+1))
done
if [ $attempt -eq $max_attempts ]; then
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
echo "Checking ClamAV service logs..."
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
exit 1
fi
echo "ClamAV is ready!"
# Verify ClamAV is responsive
echo "Testing ClamAV connection..."
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
@@ -269,13 +186,18 @@ jobs:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- name: Run pytest
- id: lint
name: Run Linter
run: poetry run lint
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
else
poetry run pytest -s -vv
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
@@ -287,12 +209,6 @@ jobs:
REDIS_PORT: "6379"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}
env:
CI: true
PLAIN_OUTPUT: True
@@ -306,3 +222,9 @@ jobs:
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}

@@ -120,6 +120,175 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
integration_test:
runs-on: ubuntu-latest
needs: setup

@@ -1,18 +1,14 @@
name: AutoGPT Platform - Full-stack CI
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
merge_group:
@@ -28,28 +24,42 @@ defaults:
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Enable corepack
run: corepack enable
- name: Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install dependencies to populate cache
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
check-api-types:
name: check API types
runs-on: ubuntu-latest
types:
runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
@@ -57,256 +67,70 @@ jobs:
with:
submodules: recursive
# ------------------------ Backend setup ------------------------
- name: Set up Backend - Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Set up Backend - Install Poetry
working-directory: autogpt_platform/backend
run: |
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Installing Poetry version ${POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
- name: Set up Backend - Set up dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Set up Backend - Install dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Set up Backend - Generate Prisma client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
- name: Set up Frontend - Export OpenAPI schema from Backend
working-directory: autogpt_platform/backend
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
# ------------------------ Frontend setup ------------------------
- name: Set up Frontend - Enable corepack
run: corepack enable
- name: Set up Frontend - Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up Frontend - Install dependencies
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
- name: Restore dependencies cache
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up Frontend - Format OpenAPI schema
id: format-schema
run: pnpm prettier --write ./src/app/api/openapi.json
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after exporting the API schema."
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo ""
echo "In the backend directory:"
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
echo ""
echo "In the frontend directory:"
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
echo "3. Run 'pnpm generate:api'"
echo "4. Run 'pnpm types'"
echo "5. Fix any TypeScript errors that may have been introduced"
echo "6. Commit and push your changes"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Set up Frontend - Generate API client
id: generate-api-client
run: pnpm orval --config ./orval.config.ts
# Continue with type generation & check even if there are schema changes
if: success() || (steps.format-schema.outcome == 'success')
- name: Check for TypeScript errors
- name: Run Typescript checks
run: pnpm types
if: success() || (steps.generate-api-client.outcome == 'success')
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: autogpt_platform/frontend/playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: autogpt_platform/frontend/test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs

@@ -56,36 +56,13 @@ AutoGPT Platform is a monorepo containing:
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Always use `--body-file` to pass PR body — avoids shell interpretation of backticks and special characters:
```bash
PR_BODY=$(mktemp)
cat > "$PR_BODY" << 'PREOF'
## Summary
- use `backticks` freely here
PREOF
gh pr create --title "..." --body-file "$PR_BODY" --base dev
rm "$PR_BODY"
```
- Run the github pre-commit hooks to ensure code quality.
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, follow a test-first approach:
1. **Write a failing test first** — create a test that reproduces the bug or validates the new behavior, marked with `@pytest.mark.xfail` (backend) or `.fixme` (Playwright). Run it to confirm it fails for the right reason.
2. **Implement the fix/feature** — write the minimal code to make the test pass.
3. **Remove the xfail marker** — once the test passes, remove the `xfail`/`.fixme` annotation and run the full test suite to confirm nothing else broke.
This ensures every change is covered by a test and that the test actually validates the intended behavior.
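The test-first steps above can be sketched as follows. This is a minimal illustration only: `slugify` and its test are hypothetical and are not part of the AutoGPT codebase. The snippet shows the state after step 3, with the marker from step 1 left as a comment so the sequence stays visible.

```python
import re


def slugify(title: str) -> str:
    """Hypothetical helper, used here only to illustrate the workflow."""
    # Step 2 (the fix): collapse every run of non-alphanumeric characters
    # into a single hyphen, then trim hyphens from both ends.
    return re.sub(r"[^a-z0-9]+", "-", title.lower()).strip("-")


# Step 1: this test was written first and carried the marker below while
# the bug still existed, so the suite reported it as an expected failure:
#
#     @pytest.mark.xfail(reason="slugify leaves repeated hyphens")
#
# Step 3: once the fix landed and the test passed, the marker was removed.
def test_slugify_collapses_separators():
    assert slugify("Add  new -- block!") == "add-new-block"
```

Running the test once while the marker is still in place confirms that it fails for the intended reason rather than because of a typo or an import error.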
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews --paginate` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments --paginate` — inline review comments (always paginate to avoid missing comments beyond page 1)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
### Conventional Commits

@@ -1,40 +0,0 @@
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Tracks authentication events (login, logout, SSO, password
-- reset, etc.) from Supabase's internal audit log.
-- Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
-- auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
-- created_at TIMESTAMPTZ When the auth event occurred
-- actor_id TEXT User ID who triggered the event
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
-- Rolling 90 days from current date
--
-- EXAMPLE QUERIES
-- -- Daily login counts
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
-- FROM analytics.auth_activities
-- WHERE action = 'login'
-- GROUP BY 1 ORDER BY 1;
--
-- -- SSO vs password login breakdown
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
-- WHERE action = 'login' GROUP BY 1;
-- =============================================================
SELECT
created_at,
payload->>'actor_id' AS actor_id,
payload->>'actor_via_sso' AS actor_via_sso,
payload->>'action' AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'

@@ -1,105 +0,0 @@
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
-- One row per agent graph execution (last 90 days).
-- Unpacks the JSONB stats column into individual numeric columns
-- and normalises the executionStatus — runs that failed due to
-- insufficient credits are reclassified as 'NO_CREDITS' for
-- easier filtering. Error messages are scrubbed of IDs and URLs
-- to allow safe grouping.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
-- platform.AgentGraph — Agent graph metadata (for name)
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
-- id TEXT Execution UUID
-- agentGraphId TEXT Agent graph UUID
-- agentGraphVersion INT Graph version number
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
-- createdAt TIMESTAMPTZ When the execution was queued
-- updatedAt TIMESTAMPTZ Last status update time
-- userId TEXT Owner user UUID
-- agentGraphName TEXT Human-readable agent name
-- cputime DECIMAL Total CPU seconds consumed
-- walltime DECIMAL Total wall-clock seconds
-- node_count DECIMAL Number of nodes in the graph
-- nodes_cputime DECIMAL CPU time across all nodes
-- nodes_walltime DECIMAL Wall time across all nodes
-- execution_cost DECIMAL Credit cost of this execution
-- correctness_score FLOAT AI correctness score (if available)
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Daily execution counts by status
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
-- FROM analytics.graph_execution
-- GROUP BY 1, 2 ORDER BY 1;
--
-- -- Average cost per execution by agent
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'COMPLETED'
-- GROUP BY 1 ORDER BY avg_cost DESC;
--
-- -- Top error messages
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================
SELECT
ge."id" AS id,
ge."agentGraphId" AS agentGraphId,
ge."agentGraphVersion" AS agentGraphVersion,
CASE
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
AND (
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
)
THEN 'NO_CREDITS'
ELSE CAST(ge."executionStatus" AS TEXT)
END AS executionStatus,
ge."createdAt" AS createdAt,
ge."updatedAt" AS updatedAt,
ge."userId" AS userId,
g."name" AS agentGraphName,
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
ON ge."agentGraphId" = g."id"
AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
SELECT DISTINCT ON ("userId", "agentGraphId")
"userId", "agentGraphId",
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
FROM platform."LibraryAgent"
WHERE "isDeleted" = FALSE
AND "isArchived" = FALSE
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
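The `groupedErrorMessage` scrubbing in the view above can be tried outside the database. The following is a rough Python approximation of the two nested `REGEXP_REPLACE` calls, not the view's actual implementation. The function name `scrub_error` and the sample message are invented for illustration, and Python's `re` engine may differ from PostgreSQL's regex engine in edge cases.

```python
import re

# Same patterns as in the view; the URL pattern is case-insensitive there ('gi').
URL_RE = re.compile(r"(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?", re.IGNORECASE)
TOKEN_WITH_DIGIT_RE = re.compile(r"[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*")


def scrub_error(message: str) -> str:
    """Approximate the view's error scrubbing (hypothetical helper)."""
    # Mirrors TRIM(BOTH '"' FROM ...): drop surrounding double quotes.
    message = message.strip('"')
    # Inner REGEXP_REPLACE: keep scheme and host, drop port and path.
    message = URL_RE.sub(r"\1\2/...", message)
    # Outer REGEXP_REPLACE: replace any token containing a digit with '*'.
    return TOKEN_WITH_DIGIT_RE.sub("*", message)


sample = '"Failed to fetch https://api.example.com:8080/v1/users/123 for user abc-123"'
print(scrub_error(sample))
# -> Failed to fetch https://api.example.com/... for user *
```

Because IDs and URL paths collapse to the same placeholder, messages that differ only in those details end up with an identical string and can be grouped with a plain `GROUP BY`.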

@@ -1,101 +0,0 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================
SELECT
ne."id" AS id,
ne."agentGraphExecutionId" AS agentGraphExecutionId,
ne."agentNodeId" AS agentNodeId,
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
ne."addedTime" AS addedTime,
ne."queuedTime" AS queuedTime,
ne."startedTime" AS startedTime,
ne."endedTime" AS endedTime,
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
b."name" AS blockName,
b."id" AS blockId,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM eio."data"::text),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage,
eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
ON eio."referencedByOutputExecId" = ne."id"
AND eio."name" = 'error'
AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'

@@ -1,97 +0,0 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e
),
first_use AS (
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1,2
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
g.agent_id,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(ac.active_users,0) AS active_users,
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('day', e."createdAt")::date AS day_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fe.cohort_day_start,
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_exec fe USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fe.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
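The difference between the bounded and unbounded counts in the retention view above can be illustrated on a toy cohort. This is a sketch in Python rather than the view's SQL. The function name `retention_counts` and the sample users are invented for illustration, and the input is assumed to be already reduced to "weeks since first execution" per user.

```python
from collections import Counter


def retention_counts(active_weeks_by_user: dict[str, set[int]], max_weeks: int):
    """Return (bounded, unbounded) user counts per lifetime week (hypothetical helper)."""
    bounded: Counter = Counter()
    unbounded: Counter = Counter()
    for weeks in active_weeks_by_user.values():
        # Bounded: the user counts only in the exact weeks they were active.
        for k in weeks:
            if 0 <= k <= max_weeks:
                bounded[k] += 1
        # Unbounded: the user counts in every week up to their last active
        # week, mirroring generate_series(0, LEAST(last_active_week, max_weeks)).
        last_active = min(max(weeks), max_weeks)
        for k in range(last_active + 1):
            unbounded[k] += 1
    return bounded, unbounded


# Toy cohort of three users; every user is active in week 0 by definition.
cohort = {"u1": {0}, "u2": {0, 2}, "u3": {0, 1, 3}}
bounded, unbounded = retention_counts(cohort, max_weeks=12)
print(bounded[1], unbounded[1])
# -> 1 2
```

In week 1 only `u3` was active, so the bounded count is 1, while the unbounded count is 2 because `u2` returned later and is therefore still treated as retained at week 1.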

@@ -1,94 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
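-- -- Average retention curve across all cohorts
-- -- (divide by cohort_users: cohort_users_d0 is 0 on every row after day 0,
-- -- so grouping by user_lifetime_day would yield NULL for days 1..30)
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users), 0) AS avg_retention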
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('day', s.created_at)::date AS day_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fl.cohort_day_start,
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_login fl USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;
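
A pooled curve across cohorts has to divide each lifetime day by the sizes of the cohorts that actually have a row for that day. The Python sketch below illustrates that weighting on made-up rows. It is not part of the view: `pooled_retention_curve` and the sample data are hypothetical names chosen for illustration.

```python
from collections import defaultdict


def pooled_retention_curve(rows):
    """Pool cohorts into one curve, weighting each cohort by its size.

    rows is an iterable of
    (user_lifetime_day, cohort_users, active_users_bounded) tuples,
    one per cohort and lifetime day, shaped like the view's output grid.
    """
    active = defaultdict(int)
    eligible = defaultdict(int)
    for day, cohort_users, active_users in rows:
        active[day] += active_users
        # Only cohorts old enough to have a row for this day contribute.
        eligible[day] += cohort_users
    return {
        day: (active[day] / eligible[day]) if eligible[day] else None
        for day in sorted(eligible)
    }


# Hypothetical data: cohort A (10 users) is two days old,
# cohort B (30 users) is one day old.
rows = [
    (0, 10, 10), (1, 10, 4), (2, 10, 2),
    (0, 30, 30), (1, 30, 6),
]
curve = pooled_retention_curve(rows)
```

Day 2 is divided by cohort A alone because cohort B has not reached it yet, which is the effect of capping the grid at each cohort's age.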

@@ -1,96 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
onboarders AS (
SELECT fl.user_id FROM first_login_all fl
WHERE EXISTS (
SELECT 1 FROM platform."AgentGraphExecution" e
WHERE e."userId"::text = fl.user_id
AND e."createdAt" >= fl.first_login_time
AND e."createdAt" < fl.first_login_time
+ make_interval(days => (SELECT onboarding_window_days FROM params))
)
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
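
The onboarding filter keeps a user when at least one agent execution falls inside a half-open window that starts at the first login. A minimal Python sketch of that predicate follows, assuming timezone-aware timestamps; `is_onboarded` is a hypothetical helper, not code from the repository.

```python
from datetime import datetime, timedelta, timezone


def is_onboarded(first_login, execution_times, window_days=365):
    """True if any execution lies in [first_login, first_login + window_days)."""
    window_end = first_login + timedelta(days=window_days)
    # Lower bound inclusive, upper bound exclusive, as in the EXISTS subquery.
    return any(first_login <= t < window_end for t in execution_times)


# Hypothetical user whose only run happened ten days after the first login.
first_login = datetime(2025, 1, 1, tzinfo=timezone.utc)
onboarded = is_onboarded(first_login, [first_login + timedelta(days=10)])
```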

@@ -1,103 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
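-- -- Overall average retention curve (all cohorts combined)
-- -- (divide by cohort_users: cohort_users_w0 is 0 on every row after week 0,
-- -- so grouping by user_lifetime_week would yield NULL for weeks 1..12)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users), 0) AS avg_retention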
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week
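
The difference between the bounded and unbounded counts is easiest to see on a tiny cohort. The Python sketch below recomputes both on in-memory data. It is an illustration under assumed inputs, not the view's implementation: `retention_counts` and the `activity` dictionary are hypothetical.

```python
from collections import Counter


def retention_counts(active_weeks_by_user, max_weeks):
    """Return (bounded, unbounded) Counters keyed by lifetime week.

    active_weeks_by_user maps a user id to the set of lifetime weeks
    (0 = first-login week) in which that user had at least one session.
    """
    bounded = Counter()
    unbounded = Counter()
    for weeks in active_weeks_by_user.values():
        # Bounded: the user counts only in the weeks they were active.
        for k in weeks:
            if 0 <= k <= max_weeks:
                bounded[k] += 1
        # Unbounded: the user counts in every week up to their last
        # active week, capped at max_weeks.
        last_active = min(max(weeks), max_weeks)
        for k in range(last_active + 1):
            unbounded[k] += 1
    return bounded, unbounded


# Hypothetical cohort of three users.
activity = {"u1": {0}, "u2": {0, 1, 3}, "u3": {0, 2}}
bounded, unbounded = retention_counts(activity, max_weeks=12)
```

User `u2` is absent in week 2 but returns in week 3, so week 2 counts that user in the unbounded series only.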

@@ -1,71 +0,0 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
-- platform.CreditTransaction — Credit debit/credit records
-- platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
-- transactionKey TEXT Unique transaction identifier
-- userId TEXT User who was charged
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
-- transactionTime TIMESTAMPTZ When the transaction was recorded
-- blockId TEXT Block UUID that triggered the spend
-- blockName TEXT Human-readable block name
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
-- node_exec_id TEXT Linked node execution UUID
-- llm_call_count INT LLM API calls made in that execution
-- llm_retry_count INT LLM retries in that execution
-- llm_input_token_count INT Input tokens consumed
-- llm_output_token_count INT Output tokens produced
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
-- SUM("negativeAmount") AS total_cost,
-- SUM("llm_input_token_count") AS input_tokens,
-- SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================
SELECT
c."transactionKey" AS "transactionKey",
c."userId" AS "userId",
c."amount" AS amount,
c."amount" * -1 AS "negativeAmount",
c."type" AS "transactionType",
c."createdAt" AS "transactionTime",
c.metadata->>'block_id' AS "blockId",
c.metadata->>'block' AS "blockName",
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
c.metadata->'input'->>'model' AS llm_model,
c.metadata->>'node_exec_id' AS node_exec_id,
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'

@@ -1,45 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
-- id TEXT Onboarding record UUID
-- createdAt TIMESTAMPTZ When onboarding started
-- updatedAt TIMESTAMPTZ Last update to onboarding state
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
-- integrations TEXT[] Array of integration names the user selected
-- userId TEXT User UUID
-- completedSteps TEXT[] Array of onboarding step enums completed
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================
SELECT
id,
"createdAt",
"updatedAt",
"usageReason",
integrations,
"userId",
"completedSteps",
"selectedStoreListingVersionId"
FROM platform."UserOnboarding"

@@ -1,100 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
-- step_order INT Numeric position in the funnel (1=first, 22=last)
-- users_completed BIGINT Distinct users who completed this step
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
-- 7 CONGRATS 15 VISIT_COPILOT
-- 8 GET_RESULTS 16 RE_RUN_AGENT
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================
WITH all_steps AS (
-- Complete ordered grid of all 22 steps so zero-completion steps
-- are always present, keeping LAG comparisons correct.
SELECT step_name, step_order
FROM (VALUES
('WELCOME', 1),
('USAGE_REASON', 2),
('INTEGRATIONS', 3),
('AGENT_CHOICE', 4),
('AGENT_NEW_RUN', 5),
('AGENT_INPUT', 6),
('CONGRATS', 7),
('GET_RESULTS', 8),
('MARKETPLACE_VISIT', 9),
('MARKETPLACE_ADD_AGENT', 10),
('MARKETPLACE_RUN_AGENT', 11),
('BUILDER_OPEN', 12),
('BUILDER_SAVE_AGENT', 13),
('BUILDER_RUN_AGENT', 14),
('VISIT_COPILOT', 15),
('RE_RUN_AGENT', 16),
('SCHEDULE_AGENT', 17),
('RUN_AGENTS', 18),
('RUN_3_DAYS', 19),
('TRIGGER_WEBHOOK', 20),
('RUN_14_DAYS', 21),
('RUN_AGENTS_100', 22)
) AS t(step_name, step_order)
),
raw AS (
SELECT
u."userId",
step_txt::text AS step
FROM platform."UserOnboarding" u
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
SELECT step, COUNT(DISTINCT "userId") AS users_completed
FROM raw GROUP BY step
),
funnel AS (
SELECT
a.step_name AS step,
a.step_order,
COALESCE(sc.users_completed, 0) AS users_completed,
ROUND(
100.0 * COALESCE(sc.users_completed, 0)
/ NULLIF(
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
0
),
2
) AS pct_from_prev
FROM all_steps a
LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order
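
The step-over-step percentage can be sketched in a few lines of Python. This is an illustration of the LAG and NULLIF logic on made-up counts, not the view itself: `funnel_with_dropoff` is a hypothetical helper, and Python's `round` may differ from PostgreSQL's `ROUND` on exact ties.

```python
def funnel_with_dropoff(step_order, completed_by_step):
    """Return (step, users_completed, pct_from_prev) rows in funnel order."""
    rows = []
    previous = None  # The first step has no predecessor, like LAG on row 1.
    for step in step_order:
        users = completed_by_step.get(step, 0)  # Missing steps count as 0.
        if previous:
            pct = round(100.0 * users / previous, 2)
        else:
            # previous is None or 0: report no percentage instead of
            # dividing by zero, mirroring NULLIF(..., 0).
            pct = None
        rows.append((step, users, pct))
        previous = users
    return rows


# Hypothetical counts; INTEGRATIONS has no completions at all.
steps = ["WELCOME", "USAGE_REASON", "INTEGRATIONS", "AGENT_CHOICE"]
counts = {"WELCOME": 200, "USAGE_REASON": 150, "AGENT_CHOICE": 30}
funnel_rows = funnel_with_dropoff(steps, counts)
```

Because the zero-completion step stays in the sequence, `AGENT_CHOICE` is compared with `INTEGRATIONS` rather than with `USAGE_REASON`, which is what the complete step grid is meant to guarantee.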

@@ -1,41 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
-- platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
-- users_with_integration BIGINT Distinct users who selected this integration
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================
WITH exploded AS (
SELECT
u."userId" AS user_id,
UNNEST(u."integrations") AS integration
FROM platform."UserOnboarding" u
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
integration,
COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC

@@ -1,145 +0,0 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
-- auth.sessions — Login/session records
-- platform.AgentGraph — Graphs (agents) built by the user
-- platform.AgentGraphExecution — Agent run history
-- platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
-- user_id TEXT Supabase user UUID
-- first_login_time TIMESTAMPTZ First ever session created_at
-- last_login_time TIMESTAMPTZ Most recent session created_at
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
-- agent_runs BIGINT Total graph execution count (0 if none)
-- node_execution_count BIGINT Total node executions across all runs
-- node_execution_failed BIGINT Node executions with FAILED status
-- node_execution_completed BIGINT Node executions with COMPLETED status
-- node_execution_terminated BIGINT Node executions with TERMINATED status
-- node_execution_queued BIGINT Node executions with QUEUED status
-- node_execution_running BIGINT Node executions with RUNNING status
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
-- node_execution_review BIGINT Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
-- AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================
WITH user_logins AS (
SELECT
user_id::text AS user_id,
MIN(created_at) AS first_login_time,
MAX(created_at) AS last_login_time,
GREATEST(
MAX(refreshed_at)::timestamptz,
MAX(created_at)::timestamptz
) AS last_visit_time
FROM auth.sessions
GROUP BY user_id
),
user_agents AS (
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
SELECT
"userId"::text AS user_id,
MAX("updatedAt") AS last_agent_save_time,
COUNT(DISTINCT "id") AS agent_count
FROM platform."AgentGraph"
WHERE "isActive"
GROUP BY "userId"
),
user_graph_runs AS (
-- Aggregate AgentGraphExecution directly by userId
SELECT
"userId"::text AS user_id,
MIN("createdAt") AS first_agent_run_time,
MAX("createdAt") AS last_agent_run_time,
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
COUNT("id") AS agent_runs
FROM platform."AgentGraphExecution"
GROUP BY "userId"
),
user_node_runs AS (
-- Aggregate AgentNodeExecution directly; resolve userId via a
-- single join to AgentGraphExecution instead of fanning out from
-- user_logins through both large tables.
SELECT
g."userId"::text AS user_id,
COUNT(*) AS node_execution_count,
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
FROM platform."AgentNodeExecution" n
JOIN platform."AgentGraphExecution" g
ON g."id" = n."agentGraphExecutionId"
GROUP BY g."userId"
)
SELECT
ul.user_id,
ul.first_login_time,
ul.last_login_time,
ul.last_visit_time,
ua.last_agent_save_time,
COALESCE(ua.agent_count, 0) AS agent_count,
gr.first_agent_run_time,
gr.last_agent_run_time,
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
COALESCE(gr.agent_runs, 0) AS agent_runs,
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
CASE
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
ELSE NULL
END AS is_active_after_7d,
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
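
The three-way `is_active_after_7d` flag can be restated as a small Python function. This is a sketch that assumes all three timestamps are present and timezone-aware; the function and its `now` parameter are hypothetical and stand in for the CASE expression and `NOW()`.

```python
from datetime import datetime, timedelta, timezone


def is_active_after_7d(first_login, last_visit, now):
    """Return 1, 0, or None following the view's three-way CASE."""
    week = timedelta(days=7)
    if first_login >= now - week:
        return None  # The user is less than 7 days old: too early to tell.
    return 1 if last_visit >= first_login + week else 0


# Hypothetical user who first logged in on 1 Jan and was last seen on 10 Jan.
now = datetime(2025, 3, 1, tzinfo=timezone.utc)
first_login = datetime(2025, 1, 1, tzinfo=timezone.utc)
flag = is_active_after_7d(first_login, first_login + timedelta(days=9), now)
```

Keeping `None` separate from `0` lets the activation rate exclude users who have not yet had a full week to return.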

@@ -58,56 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **f-strings vs printf syntax in log statements** — Use `%s` for deferred interpolation in `debug` statements, f-strings elsewhere for readability: `logger.debug("Processing %s items", count)`, `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
- **Top-down ordering** — define the main/public function or class first, then the helpers it uses below. A reader should encounter high-level logic before implementation details.
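
Several of the rules above (early return, `max(0, value)` guards, and `%s` versus f-string logging) can be combined in one short function. The sketch below is only an illustration: `remaining_credits` is a hypothetical helper and does not exist in the codebase.

```python
import logging

logger = logging.getLogger(__name__)


def remaining_credits(balance: int, cost: int) -> int:
    """Hypothetical helper: charge `cost` and never report a negative balance."""
    # Guard clause first: nothing to charge, so return early.
    if cost <= 0:
        # %s defers string interpolation until the debug record is emitted.
        logger.debug("Nothing to charge for cost=%s", cost)
        return max(0, balance)

    # Clamp the computed value so it can never go below zero.
    remaining = max(0, balance - cost)
    # f-string for readability outside debug statements.
    logger.info(f"Charged {cost} credits, {remaining} remaining")
    return remaining
```
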
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
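
As a minimal sketch of the `AsyncMock` bullet, the test below replaces an awaited dependency with a mock. `fetch_and_double` and its `client` argument are hypothetical stand-ins for real code under test, and `asyncio.run` is used only to keep the example self-contained.

```python
import asyncio
from unittest.mock import AsyncMock


async def fetch_and_double(client):
    """Hypothetical code under test: awaits an async dependency."""
    value = await client.fetch_value()
    return value * 2


def test_fetch_and_double():
    # Attributes of an AsyncMock are AsyncMocks too, so they can be awaited.
    client = AsyncMock()
    client.fetch_value.return_value = 21

    result = asyncio.run(fetch_and_double(client))

    assert result == 42
    client.fetch_value.assert_awaited_once()
```
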
### Test-Driven Development (TDD)
When fixing a bug or adding a feature, write the test **before** the implementation:
```python
# 1. Write a failing test marked xfail
@pytest.mark.xfail(reason="Bug #1234: widget crashes on empty input")
def test_widget_handles_empty_input():
result = widget.process("")
assert result == Widget.EMPTY_RESULT
# 2. Run it — confirm it fails (XFAIL)
# poetry run pytest path/to/test.py::test_widget_handles_empty_input -xvs
# 3. Implement the fix
# 4. Remove xfail, run again — confirm it passes
def test_widget_handles_empty_input():
    result = widget.process("")
    assert result == Widget.EMPTY_RESULT
```
This catches regressions and proves the fix actually works. **Every bug fix should include a test that would have caught it.**
## Database Schema
@@ -203,16 +157,6 @@ yield "image_url", result_url
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware

View File

@@ -50,7 +50,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
# =============================== DB MIGRATOR =============================== #
@@ -82,7 +82,7 @@ RUN pip3 install prisma>=0.15.0 --break-system-packages
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/scripts/gen_prisma_types_stub.py ./scripts/
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
@@ -121,37 +121,19 @@ RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) + Chromium.
# On amd64: install runtime libs + run `agent-browser install` to download
# Chrome for Testing (pinned version, tested with Playwright).
# On arm64: install system chromium package — Chrome for Testing has no ARM64
# binary. AGENT_BROWSER_EXECUTABLE_PATH is set at runtime by the entrypoint
# script (below) to redirect agent-browser to the system binary.
ARG TARGETARCH
RUN apt-get update \
&& if [ "$TARGETARCH" = "arm64" ]; then \
apt-get install -y --no-install-recommends chromium fonts-liberation; \
else \
apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1; \
fi \
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
RUN apt-get update && apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1 \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& ([ "$TARGETARCH" = "arm64" ] || agent-browser install) \
&& agent-browser install \
&& rm -rf /tmp/* /root/.npm
# On arm64 the system chromium is at /usr/bin/chromium; set
# AGENT_BROWSER_EXECUTABLE_PATH so agent-browser's daemon uses it instead of
# Chrome for Testing (which has no ARM64 binary). On amd64 the variable is left
# unset so agent-browser uses the Chrome for Testing binary it downloaded above.
RUN printf '#!/bin/sh\n[ -x /usr/bin/chromium ] && export AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium\nexec "$@"\n' \
> /usr/local/bin/entrypoint.sh \
&& chmod +x /usr/local/bin/entrypoint.sh
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
@@ -173,5 +155,4 @@ RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
ENV PORT=8000
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["rest"]

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Optional, Sequence
from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,10 +9,9 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_auth, require_permission
from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
@@ -231,13 +230,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -279,7 +278,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -307,13 +306,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -349,7 +348,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
):
"""
Get store listings with their version history for admins.
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
Paginated listings with their versions
StoreListingsWithVersionsResponse with listings and their versions
"""
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmissionAdminView:
):
"""
Review a store listing submission.
@@ -73,24 +84,31 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmissionAdminView with updated review information
StoreSubmission with updated review information
"""
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(

View File

@@ -4,12 +4,14 @@ from difflib import SequenceMatcher
from typing import Any, Sequence, get_args, get_origin
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
@@ -22,7 +24,6 @@ from backend.blocks.llm import LlmModel
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
from backend.util.text import split_camelcase
from .model import (
BlockCategoryResponse,
@@ -270,7 +271,7 @@ async def _build_cached_search_results(
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _text_search_blocks(
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
@@ -382,75 +383,117 @@ def _collect_block_results(
return results, block_count, integration_count
async def _text_search_blocks(
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using in-memory text matching over the block registry.
Search blocks using hybrid search with builder-specific filtering.
All blocks are already loaded in memory, so this is fast and reliable
regardless of whether OpenAI embeddings are available.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: text relevance via _score_primary_fields, plus BLOCK_SCORE_BOOST
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, 0, 0
return results, block_count, integration_count
normalized_query = query.strip().lower()
all_results, _, _ = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for item in all_results:
block_info = item.item
assert isinstance(block_info, BlockInfo)
name = split_camelcase(block_info.name).lower()
for result in search_results:
block_id = result["content_id"]
# Build rich description including input field descriptions,
# matching the searchable text that the embedding pipeline uses
desc_parts = [block_info.description or ""]
block_cls = all_blocks.get(block_info.id)
if block_cls is not None:
block: AnyBlockSchema = block_cls()
desc_parts += [
f"{f}: {info.description}"
for f, info in block.input_schema.model_fields.items()
if info.description
]
description = " ".join(desc_parts).lower()
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
score = _score_primary_fields(name, description, normalized_query)
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
if block_cls is not None and _matches_llm_model(
block_cls().input_schema, normalized_query
):
score += 20
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
if score >= MIN_SCORE_FOR_FILTERED_RESULTS:
results.append(
_ScoredItem(
item=block_info,
filter_type=item.filter_type,
score=score + BLOCK_SCORE_BOOST,
sort_key=name,
)
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
)
)
block_count = sum(1 for r in results if r.filter_type == "blocks")
integration_count = sum(1 for r in results if r.filter_type == "integrations")
return results, block_count, integration_count

View File

@@ -8,10 +8,10 @@ from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
@@ -25,16 +25,8 @@ from backend.copilot.model import (
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -59,7 +51,6 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -125,8 +116,6 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -136,7 +125,6 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
is_processing: bool
class ListSessionsResponse(BaseModel):
@@ -153,20 +141,6 @@ class CancelSessionResponse(BaseModel):
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -195,28 +169,6 @@ async def list_sessions(
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
if sessions:
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
"status",
)
statuses = await pipe.execute()
processing_set = {
session.session_id
for session, st in zip(sessions, statuses)
if st == "running"
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; defaulting to empty"
)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
@@ -224,7 +176,6 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
is_processing=session.session_id in processing_set,
)
for session in sessions
],
@@ -236,7 +187,7 @@ async def list_sessions(
"/sessions",
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -299,12 +250,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
try:
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
await kill_sandbox(session_id, config.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
@@ -313,49 +264,12 @@ async def delete_session(
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
@@ -396,10 +310,6 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -407,25 +317,6 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get(
"/usage",
)
async def get_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -435,7 +326,7 @@ async def get_copilot_usage(
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
@@ -480,7 +371,7 @@ async def cancel_session_task(
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str = Security(auth.get_user_id),
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
@@ -497,7 +388,7 @@ async def stream_chat_post(
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Authenticated user ID.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
@@ -506,7 +397,9 @@ async def stream_chat_post(
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
@@ -524,18 +417,6 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
if user_id:
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -770,7 +651,7 @@ async def stream_chat_post(
)
async def resume_session_stream(
session_id: str,
user_id: str = Security(auth.get_user_id),
user_id: str | None = Depends(auth.get_user_id),
):
"""
Resume an active stream for a session.
@@ -872,6 +753,7 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,

View File

@@ -1,7 +1,4 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, and rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock
"""Tests for chat route file_ids validation and enrichment."""
import fastapi
import fastapi.testclient
@@ -20,7 +17,6 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
@@ -28,95 +24,7 @@ def setup_app_auth(mock_jwt_user):
app.dependency_overrides.clear()
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
# ---- file_ids Pydantic validation (B1) ----
def test_stream_chat_rejects_too_many_file_ids():
@@ -184,7 +92,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
assert response.status_code == 200
# ─── UUID format filtering ─────────────────────────────────────────────
# ---- UUID format filtering ----
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
@@ -223,7 +131,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ─── Cross-workspace file_ids ─────────────────────────────────────────
# ---- Cross-workspace file_ids ----
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
@@ -250,153 +158,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
# Ensure the rate-limit branch is entered by setting a non-zero limit.
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
assert "daily" in response.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
resets_at = datetime.now(UTC) + timedelta(days=3)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("weekly", resets_at),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"].lower()
assert "weekly" in detail
assert "resets in" in detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
"""The 429 response detail should include the human-readable reset time."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded(
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"]
assert "2h" in detail
assert "Resets in" in detail
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
def test_usage_rejects_unauthenticated_request() -> None:
"""GET /usage should return 401 when no valid JWT is provided."""
unauthenticated_app = fastapi.FastAPI()
unauthenticated_app.include_router(chat_routes.router)
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
response = unauthenticated_client.get("/usage")
assert response.status_code == 401

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -41,38 +36,6 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -147,16 +110,14 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -199,26 +160,30 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -271,7 +236,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval and resolve their node_ids
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -279,16 +244,29 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -303,11 +281,13 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -322,20 +302,22 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
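
Both variants of `process_review_action` in this file do the same two things for auto-approval: map each `node_exec_id` to a `node_id`, then keep only the first approved review per node. The sketch below illustrates that flow in isolation. It assumes the `"{node_id}:{random}"` format described in the `_resolve_node_ids` docstring; `parse_node_id` and `first_review_per_node` are hypothetical names, not the functions in `backend.copilot.constants`.

```python
def parse_node_id(node_exec_id: str) -> str:
    """Hypothetical parser, assuming synthetic IDs look like '{node_id}:{random}'."""
    node_id, separator, _suffix = node_exec_id.rpartition(":")
    if not separator or not node_id:
        raise ValueError(f"Not a synthetic execution ID: {node_exec_id!r}")
    return node_id


def first_review_per_node(
    node_exec_ids: list[str],
    node_id_map: dict[str, str],
    reviews: dict[str, str],
) -> dict[str, str]:
    """Keep one review per node_id: the first one seen in input order."""
    deduplicated: dict[str, str] = {}
    for node_exec_id in node_exec_ids:
        node_id = node_id_map.get(node_exec_id)
        # Unresolved IDs are skipped here; the route logs an error for them.
        if node_id and node_id not in deduplicated:
            deduplicated[node_id] = reviews[node_exec_id]
    return deduplicated
```

Reviews are plain strings in this sketch to keep it self-contained; in the route they are the updated review records.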

@@ -8,6 +8,7 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.api.features.store.exceptions as store_exceptions
import backend.api.features.store.image_gen as store_image_gen
import backend.api.features.store.media as store_media
import backend.data.graph as graph_db
@@ -250,7 +251,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
The requested LibraryAgent.
Raises:
NotFoundError: If the specified agent does not exist.
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during retrieval.
"""
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
@@ -397,7 +398,6 @@ async def create_library_agent(
hitl_safe_mode: bool = True,
sensitive_action_safe_mode: bool = False,
create_library_agents_for_sub_graphs: bool = True,
folder_id: str | None = None,
) -> list[library_model.LibraryAgent]:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -414,18 +414,12 @@ async def create_library_agent(
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
Raises:
NotFoundError: If the specified agent does not exist.
AgentNotFoundError: If the specified agent does not exist.
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
)
# Authorization: FK only checks existence, not ownership.
# Verify the folder belongs to this user to prevent cross-user nesting.
if folder_id:
await get_folder(folder_id, user_id)
graph_entries = (
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
)
@@ -438,6 +432,7 @@ async def create_library_agent(
isCreatedByUser=(user_id == user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {
@@ -453,11 +448,6 @@ async def create_library_agent(
sensitive_action_safe_mode=sensitive_action_safe_mode,
).model_dump()
),
**(
{"Folder": {"connect": {"id": folder_id}}}
if folder_id and graph_entry is graph
else {}
),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
@@ -539,7 +529,6 @@ async def update_agent_version_in_library(
async def create_graph_in_library(
graph: graph_db.Graph,
user_id: str,
folder_id: str | None = None,
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
"""Create a new graph and add it to the user's library."""
graph.version = 1
@@ -553,7 +542,6 @@ async def create_graph_in_library(
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
folder_id=folder_id,
)
if created_graph.is_active:
@@ -829,7 +817,7 @@ async def add_store_agent_to_library(
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
Raises:
NotFoundError: If the store listing or associated agent is not found.
AgentNotFoundError: If the store listing or associated agent is not found.
DatabaseError: If there's an issue creating the LibraryAgent record.
"""
logger.debug(
@@ -844,7 +832,7 @@ async def add_store_agent_to_library(
)
if not store_listing_version or not store_listing_version.AgentGraph:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise NotFoundError(
raise store_exceptions.AgentNotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -858,7 +846,7 @@ async def add_store_agent_to_library(
include_subgraphs=False,
)
if not graph_model:
raise NotFoundError(
raise store_exceptions.AgentNotFoundError(
f"Graph #{graph.id} v{graph.version} not found or accessible"
)
@@ -1493,67 +1481,6 @@ async def bulk_move_agents_to_folder(
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
def collect_tree_ids(
nodes: list[library_model.LibraryFolderTree],
visited: set[str] | None = None,
) -> list[str]:
"""Collect all folder IDs from a folder tree."""
if visited is None:
visited = set()
ids: list[str] = []
for n in nodes:
if n.id in visited:
continue
visited.add(n.id)
ids.append(n.id)
ids.extend(collect_tree_ids(n.children, visited))
return ids
async def get_folder_agent_summaries(
user_id: str, folder_id: str
) -> list[dict[str, str | None]]:
"""Get a lightweight list of agents in a folder (id, name, description)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, folder_id=folder_id, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_root_agent_summaries(
user_id: str,
) -> list[dict[str, str | None]]:
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
all_agents: list[library_model.LibraryAgent] = []
for page in itertools.count(1):
resp = await list_library_agents(
user_id=user_id, include_root_only=True, page=page
)
all_agents.extend(resp.agents)
if page >= resp.pagination.total_pages:
break
return [
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
]
async def get_folder_agents_map(
user_id: str, folder_ids: list[str]
) -> dict[str, list[dict[str, str | None]]]:
"""Get agent summaries for multiple folders concurrently."""
results = await asyncio.gather(
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
)
return dict(zip(folder_ids, results))
##############################################
########### Presets DB Functions #############
##############################################
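
The summary helpers in the hunk above (`get_folder_agent_summaries`, `get_root_agent_summaries`) share one pattern: walk pages with `itertools.count(1)` and stop once the current page number reaches `total_pages`. A self-contained sketch of that loop follows. `Page` and `fetch_page` are hypothetical stand-ins for the paginated response and for `list_library_agents`.

```python
import asyncio
import itertools
from dataclasses import dataclass


@dataclass
class Page:
    items: list[str]
    total_pages: int


async def fetch_page(page: int, page_size: int = 2) -> Page:
    """Hypothetical paginated source backed by an in-memory list."""
    data = ["a", "b", "c", "d", "e"]
    start = (page - 1) * page_size
    total_pages = max(1, -(-len(data) // page_size))  # ceiling division
    return Page(items=data[start : start + page_size], total_pages=total_pages)


async def fetch_all() -> list[str]:
    """Collect every item by requesting pages until the last one is reached."""
    collected: list[str] = []
    for page in itertools.count(1):
        response = await fetch_page(page)
        collected.extend(response.items)
        # total_pages comes from the response, so the loop needs no prior count.
        if page >= response.total_pages:
            break
    return collected
```

Using `>=` rather than `==` in the stop condition means the loop still terminates if the source reports zero pages.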

@@ -4,6 +4,7 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(db.NotFoundError):
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
image_url: str | None
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool = pydantic.Field(
description="Indicates whether the same user owns the corresponding graph"
)
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url_host(auth_server_url)
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url_host(registration_endpoint)
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
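
Every call site in this file validates a user-supplied URL before contacting it, and the comments say the check blocks loopback and private IP ranges. The sketch below shows what such a check can look like for literal IP hosts using only the standard library. It is an illustration under that assumption, not the implementation of `validate_url` or `validate_url_host` in `backend.util.request`; in particular it does not resolve DNS names.

```python
import ipaddress
from urllib.parse import urlparse


def is_blocked_address(ip: str) -> bool:
    """Return True for addresses that should never be reached from the server."""
    address = ipaddress.ip_address(ip)
    return (
        address.is_loopback
        or address.is_private
        or address.is_link_local
        or address.is_reserved
        or address.is_multicast
        or address.is_unspecified
    )


def check_literal_host(url: str) -> None:
    """Hypothetical check: raise ValueError for bad schemes or blocked literal IPs."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"Unsupported scheme: {parsed.scheme!r}")
    host = parsed.hostname
    if not host:
        raise ValueError("URL has no host")
    try:
        ipaddress.ip_address(host)
    except ValueError:
        # Not a literal IP. A complete validator would resolve the name
        # and apply is_blocked_address to every resolved address.
        return
    if is_blocked_address(host):
        raise ValueError(f"Blocked address: {host}")
```

Raising `ValueError` matches how the routes consume the validator: they catch it and translate it into an HTTP 400.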

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url_host is enforced on all endpoints."""
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

@@ -1,3 +1,5 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -21,7 +23,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
search_query: str | None,
category: str | None,
page: int,
@@ -55,7 +57,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
page: int,
page_size: int,
):
@@ -73,4 +75,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator(username=username.lower())
return await store_db.get_store_creator_details(username=username.lower())

@@ -5,26 +5,16 @@ Pluggable system for different content sources (store agents, blocks, docs).
Each handler knows how to fetch and process its content type for embedding.
"""
from __future__ import annotations
import asyncio
import functools
import itertools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, get_args, get_origin
from typing import Any, get_args, get_origin
from prisma.enums import ContentType
from backend.blocks import get_blocks
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.util.text import split_camelcase
if TYPE_CHECKING:
from backend.blocks._base import AnyBlockSchema
logger = logging.getLogger(__name__)
@@ -164,28 +154,6 @@ class StoreAgentHandler(ContentHandler):
}
@functools.lru_cache(maxsize=1)
def _get_enabled_blocks() -> dict[str, AnyBlockSchema]:
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
Disabled blocks and blocks that fail to instantiate are silently skipped
(with a warning log), so callers never need their own try/except loop.
Results are cached for the process lifetime via ``lru_cache`` because
blocks are registered at import time and never change while running.
"""
enabled: dict[str, AnyBlockSchema] = {}
for block_id, block_cls in get_blocks().items():
try:
instance = block_cls()
except Exception as e:
logger.warning(f"Skipping block {block_id}: init failed: {e}")
continue
if not instance.disabled:
enabled[block_id] = instance
return enabled
class BlockHandler(ContentHandler):
"""Handler for block definitions (Python classes)."""
@@ -195,14 +163,16 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
# to_thread keeps the first (heavy) call off the event loop. On
# subsequent calls the lru_cache makes this a dict lookup, so the
# thread-pool overhead is negligible compared to the DB queries below.
enabled = await asyncio.to_thread(_get_enabled_blocks)
if not enabled:
from backend.blocks import get_blocks
# Get all available blocks
all_blocks = get_blocks()
# Check which ones have embeddings
if not all_blocks:
return []
block_ids = list(enabled.keys())
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
@@ -217,42 +187,52 @@ class BlockHandler(ContentHandler):
)
existing_ids = {row["contentId"] for row in existing_result}
missing_blocks = [
(block_id, block_cls)
for block_id, block_cls in all_blocks.items()
if block_id not in existing_ids
]
# Convert to ContentItem — disabled filtering already done by
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
missing = ((bid, b) for bid, b in enabled.items() if bid not in existing_ids)
# Convert to ContentItem
items = []
for block_id, block in itertools.islice(missing, batch_size):
for block_id, block_cls in missing_blocks[:batch_size]:
try:
block_instance = block_cls()
if block_instance.disabled:
continue
# Build searchable text from block metadata
if not block.name:
logger.warning(
f"Block {block_id} has no name — using block_id as fallback"
)
display_name = split_camelcase(block.name) if block.name else ""
parts = []
if display_name:
parts.append(display_name)
if block.description:
parts.append(block.description)
if block.categories:
parts.append(" ".join(str(cat.value) for cat in block.categories))
if block_instance.name:
parts.append(block_instance.name)
if block_instance.description:
parts.append(block_instance.description)
if block_instance.categories:
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block.input_schema.model_fields.items()
for field_name, field_info in block_input_fields.items()
if field_info.description
]
searchable_text = " ".join(parts)
categories_list = (
[cat.value for cat in block.categories] if block.categories else []
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = block.input_schema.get_credentials_fields_info()
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
@@ -263,7 +243,7 @@ class BlockHandler(ContentHandler):
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block.input_schema.model_fields.values()
for field in block_instance.input_schema.model_fields.values()
)
items.append(
@@ -272,13 +252,13 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": display_name or block.name or block_id,
"name": block_instance.name,
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None,
user_id=None, # Blocks are public
)
)
except Exception as e:
@@ -289,13 +269,22 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
enabled = await asyncio.to_thread(_get_enabled_blocks)
total_blocks = len(enabled)
from backend.blocks import get_blocks
all_blocks = get_blocks()
# Filter out disabled blocks - they're not indexed
enabled_block_ids = [
block_id
for block_id, block_cls in all_blocks.items()
if not block_cls().disabled
]
total_blocks = len(enabled_block_ids)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = list(enabled.keys())
block_ids = enabled_block_ids
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
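
The two versions of `get_missing_items` in this file differ in when disabled blocks are dropped. One filters first and then takes `batch_size` items with `itertools.islice`; the other slices `missing_blocks[:batch_size]` and skips disabled blocks afterwards. The sketch below shows why the order matters. `FakeBlock` and `REGISTRY` are hypothetical stand-ins for block instances and for the result of `get_blocks()`.

```python
import itertools
from dataclasses import dataclass


@dataclass
class FakeBlock:
    name: str
    disabled: bool = False


# Hypothetical registry: the first two entries are disabled.
REGISTRY: dict[str, FakeBlock] = {
    "a": FakeBlock("A", disabled=True),
    "b": FakeBlock("B", disabled=True),
    "c": FakeBlock("C"),
    "d": FakeBlock("D"),
}


def batch_filter_then_slice(existing_ids: set[str], batch_size: int) -> list[str]:
    """Drop disabled blocks first, then take up to batch_size missing IDs."""
    missing = (
        block_id
        for block_id, block in REGISTRY.items()
        if not block.disabled and block_id not in existing_ids
    )
    return list(itertools.islice(missing, batch_size))


def batch_slice_then_filter(existing_ids: set[str], batch_size: int) -> list[str]:
    """Take batch_size missing IDs first, then drop the disabled ones."""
    candidates = [block_id for block_id in REGISTRY if block_id not in existing_ids]
    return [
        block_id
        for block_id in candidates[:batch_size]
        if not REGISTRY[block_id].disabled
    ]
```

With a batch size of 2 and no existing embeddings, the first function returns two enabled blocks while the second returns none, because both slots in its batch are taken by disabled blocks.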

@@ -1,5 +1,7 @@
"""
Tests for content handlers (blocks, store agents, documentation).
E2E tests for content handlers (blocks, store agents, documentation).
Tests the full flow: discovering content → generating embeddings → storing.
"""
from pathlib import Path
@@ -13,103 +15,15 @@ from backend.api.features.store.content_handlers import (
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
_get_enabled_blocks,
)
@pytest.fixture(autouse=True)
def _clear_block_cache():
"""Clear the lru_cache on _get_enabled_blocks before each test."""
_get_enabled_blocks.cache_clear()
yield
_get_enabled_blocks.cache_clear()
# ---------------------------------------------------------------------------
# Helper to build a mock block class that returns a pre-configured instance
# ---------------------------------------------------------------------------
def _make_block_class(
*,
name: str = "Block",
description: str = "",
disabled: bool = False,
categories: list[MagicMock] | None = None,
fields: dict[str, str] | None = None,
raise_on_init: Exception | None = None,
) -> MagicMock:
cls = MagicMock()
if raise_on_init is not None:
cls.side_effect = raise_on_init
return cls
inst = MagicMock()
inst.name = name
inst.disabled = disabled
inst.description = description
inst.categories = categories or []
field_mocks = {
fname: MagicMock(description=fdesc) for fname, fdesc in (fields or {}).items()
}
inst.input_schema.model_fields = field_mocks
inst.input_schema.get_credentials_fields_info.return_value = {}
cls.return_value = inst
return cls
# ---------------------------------------------------------------------------
# _get_enabled_blocks
# ---------------------------------------------------------------------------
def test_get_enabled_blocks_filters_disabled():
"""Disabled blocks are excluded."""
blocks = {
"enabled": _make_block_class(name="E", disabled=False),
"disabled": _make_block_class(name="D", disabled=True),
}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
result = _get_enabled_blocks()
assert list(result.keys()) == ["enabled"]
def test_get_enabled_blocks_skips_broken():
"""Blocks that raise on init are skipped without crashing."""
blocks = {
"good": _make_block_class(name="Good"),
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
result = _get_enabled_blocks()
assert list(result.keys()) == ["good"]
def test_get_enabled_blocks_cached():
"""_get_enabled_blocks() calls get_blocks() only once across multiple calls."""
blocks = {"b1": _make_block_class(name="B1")}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
) as mock_get_blocks:
result1 = _get_enabled_blocks()
result2 = _get_enabled_blocks()
assert result1 is result2
mock_get_blocks.assert_called_once()
# ---------------------------------------------------------------------------
# StoreAgentHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
"""Test StoreAgentHandler fetches approved agents without embeddings."""
handler = StoreAgentHandler()
# Mock database query
mock_missing = [
{
"id": "agent-1",
@@ -140,7 +54,9 @@ async def test_store_agent_handler_get_stats(mocker):
"""Test StoreAgentHandler returns correct stats."""
handler = StoreAgentHandler()
# Mock approved count query
mock_approved = [{"count": 50}]
# Mock embedded count query
mock_embedded = [{"count": 30}]
with patch(
@@ -154,130 +70,74 @@ async def test_store_agent_handler_get_stats(mocker):
assert stats["without_embeddings"] == 20
# ---------------------------------------------------------------------------
# BlockHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items():
async def test_block_handler_get_missing_items(mocker):
"""Test BlockHandler discovers blocks without embeddings."""
handler = BlockHandler()
blocks = {
"block-uuid-1": _make_block_class(
name="CalculatorBlock",
description="Performs calculations",
categories=[MagicMock(value="MATH")],
fields={"expression": "Math expression to evaluate"},
),
}
# Mock get_blocks to return test blocks
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
# Mock existing embeddings query (no embeddings exist)
mock_existing = []
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
return_value=mock_existing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "block-uuid-1"
assert items[0].content_type == ContentType.BLOCK
# CamelCase should be split in searchable text and metadata name
assert "Calculator Block" in items[0].searchable_text
assert "Performs calculations" in items[0].searchable_text
assert "MATH" in items[0].searchable_text
assert "expression: Math expression" in items[0].searchable_text
assert items[0].metadata["name"] == "Calculator Block"
assert items[0].user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items_splits_camelcase():
"""CamelCase block names are split for better search indexing."""
handler = BlockHandler()
blocks = {
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert "AI Text Generator Block" in items[0].searchable_text
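The CamelCase test above expects `AITextGeneratorBlock` to be indexed as `AI Text Generator Block`. The handler's actual splitting code is not part of this diff, so the following is only a minimal sketch of one way to get that behaviour. The helper name `split_camelcase` and the regular expression are assumptions, not the project's implementation.

```python
import re

# Hypothetical helper: the real BlockHandler logic is not shown in this diff.
# A boundary is inserted either between a lowercase letter/digit and an
# uppercase letter, or between an acronym and the capitalised word after it.
_CAMEL_BOUNDARY = re.compile(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")


def split_camelcase(name: str) -> str:
    """Insert spaces at CamelCase word boundaries, keeping acronyms intact."""
    return _CAMEL_BOUNDARY.sub(" ", name)


if __name__ == "__main__":
    print(split_camelcase("AITextGeneratorBlock"))
    print(split_camelcase("CalculatorBlock"))
```

The second alternative in the pattern is what keeps `AI` together instead of producing `A I`.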
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items_batch_size_zero():
"""batch_size=0 returns an empty list; the DB is still queried to find missing IDs."""
handler = BlockHandler()
blocks = {"b1": _make_block_class(name="B1")}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
) as mock_query:
items = await handler.get_missing_items(batch_size=0)
assert items == []
# DB query is still issued to learn which blocks lack embeddings;
# the empty result comes from itertools.islice limiting to 0 items.
mock_query.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_disabled_dont_exhaust_batch():
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
handler = BlockHandler()
# 5 disabled + 3 enabled, batch_size=2
blocks = {
**{
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
},
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=2)
assert len(items) == 2
assert all(item.content_id.startswith("en-") for item in items)
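The two batching tests above describe a property worth spelling out: disabled blocks are filtered out before the batch limit is applied, and a limit of zero yields an empty list. A comment in the test attributes this to `itertools.islice`. The sketch below illustrates the idea with a deliberately simplified data model, a dict of block ID to a `disabled` flag. That model and the function name `take_enabled` are assumptions for illustration only.

```python
from itertools import islice


def take_enabled(blocks: dict[str, bool], batch_size: int) -> list[str]:
    """Return up to batch_size IDs of enabled blocks.

    `blocks` maps block_id -> disabled flag (a hypothetical simplification
    of the block classes used in the tests).
    """
    # Filter lazily first, so disabled entries never count against the limit.
    enabled_ids = (block_id for block_id, disabled in blocks.items() if not disabled)
    # islice stops after batch_size items; with 0 it yields nothing.
    return list(islice(enabled_ids, batch_size))


if __name__ == "__main__":
    blocks = {
        **{f"dis-{i}": True for i in range(5)},
        **{f"en-{i}": False for i in range(3)},
    }
    print(take_enabled(blocks, 2))
    print(take_enabled(blocks, 0))
```

Applying the limit before the filter would let the five disabled entries use up the whole batch, which is the failure mode the test name refers to.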
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats():
async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
blocks = {
"block-1": _make_block_class(name="B1"),
"block-2": _make_block_class(name="B2"),
"block-3": _make_block_class(name="B3"),
# Mock get_blocks - each block class returns an instance with disabled=False
def make_mock_block_class():
mock_class = MagicMock()
mock_instance = MagicMock()
mock_instance.disabled = False
mock_class.return_value = mock_instance
return mock_class
mock_blocks = {
"block-1": make_mock_block_class(),
"block-2": make_mock_block_class(),
"block-3": make_mock_block_class(),
}
# Mock embedded count query (2 blocks have embeddings)
mock_embedded = [{"count": 2}]
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
@@ -290,123 +150,21 @@ async def test_block_handler_get_stats():
assert stats["without_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats_skips_broken():
"""get_stats skips broken blocks instead of crashing."""
handler = BlockHandler()
blocks = {
"good": _make_block_class(name="Good"),
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
}
mock_embedded = [{"count": 1}]
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 1 # only the good block
assert stats["with_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_none_name():
"""When block.name is None the fallback display name logic is used."""
handler = BlockHandler()
blocks = {
"none-name-block": _make_block_class(
name="placeholder", # will be overridden to None below
description="A block with no name",
),
}
# Override the name to None after construction so _make_block_class
# doesn't interfere with the mock wiring.
blocks["none-name-block"].return_value.name = None
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
# display_name should be "" because block.name is None
# searchable_text should still contain the description
assert "A block with no name" in items[0].searchable_text
# metadata["name"] falls back to block_id when both display_name
# and block.name are falsy, ensuring it is always a non-empty string.
assert items[0].metadata["name"] == "none-name-block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
blocks = {
"good-block": _make_block_class(name="Good Block", description="Works fine"),
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
}
with patch(
"backend.api.features.store.content_handlers.get_blocks", return_value=blocks
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "good-block"
# ---------------------------------------------------------------------------
# DocumentationHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
"""Test DocumentationHandler discovers docs without embeddings."""
handler = DocumentationHandler()
# Create temporary docs directory with test files
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
# Mock _get_docs_root to return temp dir
with patch.object(handler, "_get_docs_root", return_value=docs_root):
# Mock existing embeddings query (no embeddings exist)
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
@@ -415,6 +173,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
assert len(items) == 2
# Check guide.md (content_id format: doc_path::section_index)
guide_item = next(
(item for item in items if item.content_id == "guide.md::0"), None
)
@@ -425,6 +184,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
assert guide_item.metadata["doc_title"] == "Getting Started"
assert guide_item.user_id is None
# Check api.mdx (content_id format: doc_path::section_index)
api_item = next(
(item for item in items if item.content_id == "api.mdx::0"), None
)
@@ -437,12 +197,14 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
"""Test DocumentationHandler returns correct stats."""
handler = DocumentationHandler()
# Create temporary docs directory
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "doc1.md").write_text("# Doc 1")
(docs_root / "doc2.md").write_text("# Doc 2")
(docs_root / "doc3.mdx").write_text("# Doc 3")
# Mock embedded count query (1 doc has embedding)
mock_embedded = [{"count": 1}]
with patch.object(handler, "_get_docs_root", return_value=docs_root):
@@ -462,11 +224,13 @@ async def test_documentation_handler_title_extraction(tmp_path):
"""Test DocumentationHandler extracts title from markdown heading."""
handler = DocumentationHandler()
# Test with heading
doc_with_heading = tmp_path / "with_heading.md"
doc_with_heading.write_text("# My Title\n\nContent here")
title = handler._extract_doc_title(doc_with_heading)
assert title == "My Title"
# Test without heading
doc_without_heading = tmp_path / "no-heading.md"
doc_without_heading.write_text("Just content, no heading")
title = handler._extract_doc_title(doc_without_heading)
@@ -478,6 +242,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
"""Test DocumentationHandler chunks markdown by headings."""
handler = DocumentationHandler()
# Test document with multiple sections
doc_with_sections = tmp_path / "sections.md"
doc_with_sections.write_text(
"# Document Title\n\n"
@@ -489,6 +254,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
)
sections = handler._chunk_markdown_by_headings(doc_with_sections)
# Should have 3 sections: intro (with doc title), section one, section two
assert len(sections) == 3
assert sections[0].title == "Document Title"
assert sections[0].index == 0
@@ -502,6 +268,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
assert sections[2].index == 2
assert "Content for section two" in sections[2].content
# Test document without headings
doc_no_sections = tmp_path / "no-sections.md"
doc_no_sections.write_text("Just plain content without any headings.")
sections = handler._chunk_markdown_by_headings(doc_no_sections)
@@ -515,39 +282,21 @@ async def test_documentation_handler_section_content_ids():
"""Test DocumentationHandler creates and parses section content IDs."""
handler = DocumentationHandler()
# Test making content ID
content_id = handler._make_section_content_id("docs/guide.md", 2)
assert content_id == "docs/guide.md::2"
# Test parsing content ID
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
assert doc_path == "docs/guide.md"
assert section_index == 2
# Test parsing legacy format (no section index)
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
assert doc_path == "docs/old-format.md"
assert section_index == 0
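The assertions above pin down the section content ID format: `doc_path::section_index`, with legacy IDs that lack a section index parsed as section 0. A minimal sketch that satisfies those assertions might look like the following. These are free functions with assumed names; the handler's real `_make_section_content_id` and `_parse_section_content_id` methods are not shown in this diff and may differ.

```python
SEPARATOR = "::"  # taken from the "doc_path::section_index" format in the tests


def make_section_content_id(doc_path: str, section_index: int) -> str:
    """Build a content ID such as 'docs/guide.md::2'."""
    return f"{doc_path}{SEPARATOR}{section_index}"


def parse_section_content_id(content_id: str) -> tuple[str, int]:
    """Split a content ID into (doc_path, section_index).

    Legacy IDs without a numeric section suffix are treated as section 0.
    """
    doc_path, sep, index = content_id.rpartition(SEPARATOR)
    if not sep or not index.isdigit():
        return content_id, 0
    return doc_path, int(index)


if __name__ == "__main__":
    print(make_section_content_id("docs/guide.md", 2))
    print(parse_section_content_id("docs/guide.md::2"))
    print(parse_section_content_id("docs/old-format.md"))
```

Splitting from the right with `rpartition` is a design choice in this sketch: it keeps a path that itself contains `::` intact as long as the final segment is the numeric index.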
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
"""Test all content types are registered."""
@@ -558,3 +307,88 @@ async def test_content_handlers_registry():
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
# Mock block with empty values (all attributes exist but are falsy)
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
# Mock one good block and one bad block
good_block = MagicMock()
good_instance = MagicMock()
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()
bad_block.side_effect = Exception("Instantiation failed")
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
# Should only get the good block
assert len(items) == 1
assert items[0].content_id == "good-block"
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
# Mock _get_docs_root to return non-existent path
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0

File diff suppressed because it is too large


@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
mock_agents = [
prisma.models.StoreAgent(
listing_id="test-id",
listing_version_id="version123",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
graph_id="test-graph-id",
graph_versions=["1"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=False,
use_for_onboarding=False,
useForOnboarding=False,
)
]
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data - StoreAgent view already contains the active version data
# Mock data
mock_agent = prisma.models.StoreAgent(
listing_id="test-id",
listing_version_id="version123",
storeListingVersionId="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
@@ -85,38 +85,102 @@ async def test_get_store_agent_details(mocker):
runs=10,
rating=4.5,
versions=["1.0"],
graph_id="test-graph-id",
graph_versions=["1"],
agentGraphVersions=["1"],
agentGraphId="test-graph-id",
updated_at=datetime.now(),
is_available=True,
use_for_onboarding=False,
is_available=False,
useForOnboarding=False,
)
# Mock StoreAgent prisma call
# Mock active version agent (what we want to return for active version)
mock_active_agent = prisma.models.StoreAgent(
listing_id="test-id",
storeListingVersionId="active-version-id",
slug="test-agent",
agent_name="Test Agent Active",
agent_video="active_video.mp4",
agent_image=["active_image.jpg"],
featured=False,
creator_username="creator",
creator_avatar="avatar.jpg",
sub_heading="Test heading active",
description="Test description active",
categories=["test"],
runs=15,
rating=4.8,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id-active",
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
)
# Create a mock StoreListing result
mock_store_listing = mocker.MagicMock()
mock_store_listing.activeVersionId = "active-version-id"
mock_store_listing.hasApprovedVersion = True
mock_store_listing.ActiveVersion = mocker.MagicMock()
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
# Mock StoreAgent prisma call - need to handle multiple calls
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Set up side_effect to return different results for different calls
def mock_find_first_side_effect(*args, **kwargs):
where_clause = kwargs.get("where", {})
if "storeListingVersionId" in where_clause:
# Second call for active version
return mock_active_agent
else:
# First call for initial lookup
return mock_agent
mock_store_agent.return_value.find_first = mocker.AsyncMock(
side_effect=mock_find_first_side_effect
)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_store_listing
)
# Call function
result = await db.get_store_agent_details("creator", "test-agent")
# Verify results - constructed from the StoreAgent view
# Verify results - should use active version data
assert result.slug == "test-agent"
assert result.agent_name == "Test Agent"
assert result.active_version_id == "version123"
assert result.agent_name == "Test Agent Active" # From active version
assert result.active_version_id == "active-version-id"
assert result.has_approved_version is True
assert result.store_listing_version_id == "version123"
assert result.graph_id == "test-graph-id"
assert result.runs == 10
assert result.rating == 4.5
assert (
result.store_listing_version_id == "active-version-id"
) # Should be active version ID
# Verify single StoreAgent lookup
mock_store_agent.return_value.find_first.assert_called_once_with(
# Verify mocks called correctly - now expecting 2 calls
assert mock_store_agent.return_value.find_first.call_count == 2
# Check the specific calls
calls = mock_store_agent.return_value.find_first.call_args_list
assert calls[0] == mocker.call(
where={"creator_username": "creator", "slug": "test-agent"}
)
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator(mocker):
async def test_get_store_creator_details(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
name="Test Creator",
@@ -138,7 +202,7 @@ async def test_get_store_creator(mocker):
mock_creator.return_value.find_unique.return_value = mock_creator_data
# Call function
result = await db.get_store_creator("creator")
result = await db.get_store_creator_details("creator")
# Verify results
assert result.username == "creator"
@@ -154,110 +218,61 @@ async def test_get_store_creator(mocker):
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
now = datetime.now()
# Mock agent graph (with no pending submissions) and user with profile
mock_profile = prisma.models.Profile(
id="profile-id",
userId="user-id",
name="Test User",
username="testuser",
description="Test",
isFeatured=False,
links=[],
createdAt=now,
updatedAt=now,
)
mock_user = prisma.models.User(
id="user-id",
email="test@example.com",
createdAt=now,
updatedAt=now,
Profile=[mock_profile],
emailVerified=True,
metadata="{}", # type: ignore[reportArgumentType]
integrations="",
maxEmailsPerDay=1,
notifyOnAgentRun=True,
notifyOnZeroBalance=True,
notifyOnLowBalance=True,
notifyOnBlockExecutionFailed=True,
notifyOnContinuousAgentError=True,
notifyOnDailySummary=True,
notifyOnWeeklySummary=True,
notifyOnMonthlySummary=True,
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
)
# Mock data
mock_agent = prisma.models.AgentGraph(
id="agent-id",
version=1,
userId="user-id",
createdAt=now,
createdAt=datetime.now(),
isActive=True,
StoreListingVersions=[],
User=mock_user,
)
# Mock the created StoreListingVersion (returned by create)
mock_store_listing_obj = prisma.models.StoreListing(
mock_listing = prisma.models.StoreListing(
id="listing-id",
createdAt=now,
updatedAt=now,
createdAt=datetime.now(),
updatedAt=datetime.now(),
isDeleted=False,
hasApprovedVersion=False,
slug="test-agent",
agentGraphId="agent-id",
owningUserId="user-id",
useForOnboarding=False,
)
mock_version = prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=now,
updatedAt=now,
subHeading="",
imageUrls=[],
categories=[],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
submittedAt=now,
StoreListing=mock_store_listing_obj,
owningUserId="user-id",
Versions=[
prisma.models.StoreListingVersion(
id="version-id",
agentGraphId="agent-id",
agentGraphVersion=1,
name="Test Agent",
description="Test description",
createdAt=datetime.now(),
updatedAt=datetime.now(),
subHeading="Test heading",
imageUrls=["image.jpg"],
categories=["test"],
isFeatured=False,
isDeleted=False,
version=1,
storeListingId="listing-id",
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isAvailable=True,
)
],
useForOnboarding=False,
)
# Mock prisma calls
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock transaction context manager
mock_tx = mocker.MagicMock()
mocker.patch(
"backend.api.features.store.db.transaction",
return_value=mocker.AsyncMock(
__aenter__=mocker.AsyncMock(return_value=mock_tx),
__aexit__=mocker.AsyncMock(return_value=False),
),
)
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
# Call function
result = await db.create_store_submission(
user_id="user-id",
graph_id="agent-id",
graph_version=1,
agent_id="agent-id",
agent_version=1,
slug="test-agent",
name="Test Agent",
description="Test description",
@@ -266,11 +281,11 @@ async def test_create_store_submission(mocker):
# Verify results
assert result.name == "Test Agent"
assert result.description == "Test description"
assert result.listing_version_id == "version-id"
assert result.store_listing_version_id == "version-id"
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_slv.return_value.create.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@@ -303,6 +318,7 @@ async def test_update_profile(mocker):
description="Test description",
links=["link1"],
avatar_url="avatar.jpg",
is_featured=False,
)
# Call function
@@ -373,7 +389,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by=db.StoreAgentsSortOptions.RATING,
sorted_by="rating",
page=1,
page_size=20,
)


@@ -15,7 +15,6 @@ from prisma.enums import ContentType
from tiktoken import encoding_for_model
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
from backend.blocks import get_blocks
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
from backend.util.clients import get_openai_client
from backend.util.json import dumps
@@ -663,6 +662,8 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
)
current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK:
from backend.blocks import get_blocks
current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION:
# Use DocumentationHandler to get section-based content IDs


@@ -57,6 +57,12 @@ class StoreError(ValueError):
pass
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""


@@ -31,10 +31,12 @@ logger = logging.getLogger(__name__)
def tokenize(text: str) -> list[str]:
"""Tokenize text for BM25."""
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
if not text:
return []
return re.findall(r"\b\w+\b", text.lower())
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r"\b\w+\b", text.lower())
return tokens
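The tokenizer above is short enough to check by hand. The standalone copy below reproduces the same regular expression so its behaviour on a few inputs can be verified. The inputs `snake_case_name` and `well-known` are illustrative additions, not cases taken from the repository's tests.

```python
import re


def tokenize(text: str) -> list[str]:
    """Lowercase the text and return runs of word characters."""
    if not text:
        return []
    return re.findall(r"\b\w+\b", text.lower())


if __name__ == "__main__":
    # A CamelCase identifier stays a single token once lowercased.
    print(tokenize("AITextGeneratorBlock"))
    # Whitespace and punctuation act as separators.
    print(tokenize("hello world"))
    print(tokenize("well-known"))
    # \w also matches "_", so underscores do not split tokens.
    print(tokenize("snake_case_name"))
```

Two consequences follow. The docstring's "split on non-alphanumeric" is slightly loose, since `\w` keeps underscores inside a token. And a CamelCase name contributes only one BM25 term unless it is split into words before tokenizing.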
def bm25_rerank(
@@ -566,7 +568,7 @@ async def hybrid_search(
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa.listing_version_id
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -580,7 +582,7 @@ async def hybrid_search(
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa.listing_version_id
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
@@ -603,7 +605,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa.graph_id,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -625,9 +627,9 @@ async def hybrid_search(
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa.listing_version_id
ON c."storeListingVersionId" = sa."storeListingVersionId"
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa.listing_version_id = uce."contentId"
ON sa."storeListingVersionId" = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
@@ -663,7 +665,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
graph_id,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,


@@ -14,27 +14,9 @@ from backend.api.features.store.hybrid_search import (
HybridSearchWeights,
UnifiedSearchWeights,
hybrid_search,
tokenize,
unified_hybrid_search,
)
# ---------------------------------------------------------------------------
# tokenize (BM25)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"input_text, expected",
[
("AITextGeneratorBlock", ["aitextgeneratorblock"]),
("hello world", ["hello", "world"]),
("", []),
("HTTPRequest", ["httprequest"]),
],
)
def test_tokenize(input_text: str, expected: list[str]):
assert tokenize(input_text) == expected
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration


@@ -1,14 +1,11 @@
import datetime
from typing import TYPE_CHECKING, List, Self
from typing import List
import prisma.enums
import pydantic
from backend.util.models import Pagination
if TYPE_CHECKING:
import prisma.models
class ChangelogEntry(pydantic.BaseModel):
version: str
@@ -16,9 +13,9 @@ class ChangelogEntry(pydantic.BaseModel):
date: datetime.datetime
class MyUnpublishedAgent(pydantic.BaseModel):
graph_id: str
graph_version: int
class MyAgent(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_name: str
agent_image: str | None = None
description: str
@@ -26,8 +23,8 @@ class MyUnpublishedAgent(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
agents: list[MyUnpublishedAgent]
class MyAgentsResponse(pydantic.BaseModel):
agents: list[MyAgent]
pagination: Pagination
@@ -43,21 +40,6 @@ class StoreAgent(pydantic.BaseModel):
rating: float
agent_graph_id: str
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
return cls(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.graph_id,
)
class StoreAgentsResponse(pydantic.BaseModel):
agents: list[StoreAgent]
@@ -80,192 +62,81 @@ class StoreAgentDetails(pydantic.BaseModel):
runs: int
rating: float
versions: list[str]
graph_id: str
graph_versions: list[str]
agentGraphVersions: list[str]
agentGraphId: str
last_updated: datetime.datetime
recommended_schedule_cron: str | None = None
active_version_id: str
has_approved_version: bool
active_version_id: str | None = None
has_approved_version: bool = False
# Optional changelog data when include_changelog=True
changelog: list[ChangelogEntry] | None = None
@classmethod
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
return cls(
store_listing_version_id=agent.listing_version_id,
slug=agent.slug,
agent_name=agent.agent_name,
agent_video=agent.agent_video or "",
agent_output_demo=agent.agent_output_demo or "",
agent_image=agent.agent_image,
creator=agent.creator_username or "",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
categories=agent.categories,
runs=agent.runs,
rating=agent.rating,
versions=agent.versions,
graph_id=agent.graph_id,
graph_versions=agent.graph_versions,
last_updated=agent.updated_at,
recommended_schedule_cron=agent.recommended_schedule_cron,
active_version_id=agent.listing_version_id,
has_approved_version=True, # StoreAgent view only has approved agents
)
class Profile(pydantic.BaseModel):
"""Marketplace user profile (only attributes that the user can update)"""
username: str
class Creator(pydantic.BaseModel):
name: str
username: str
description: str
avatar_url: str | None
links: list[str]
class ProfileDetails(Profile):
"""Marketplace user profile (including read-only fields)"""
is_featured: bool
@classmethod
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
return cls(
name=profile.name,
username=profile.username,
avatar_url=profile.avatarUrl,
description=profile.description,
links=profile.links,
is_featured=profile.isFeatured,
)
class CreatorDetails(ProfileDetails):
"""Marketplace creator profile details, including aggregated stats"""
avatar_url: str
num_agents: int
agent_runs: int
agent_rating: float
top_categories: list[str]
@classmethod
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
return cls(
name=creator.name,
username=creator.username,
avatar_url=creator.avatar_url,
description=creator.description,
links=creator.links,
is_featured=creator.is_featured,
num_agents=creator.num_agents,
agent_runs=creator.agent_runs,
agent_rating=creator.agent_rating,
top_categories=creator.top_categories,
)
agent_runs: int
is_featured: bool
class CreatorsResponse(pydantic.BaseModel):
creators: List[CreatorDetails]
creators: List[Creator]
pagination: Pagination
class StoreSubmission(pydantic.BaseModel):
# From StoreListing:
listing_id: str
user_id: str
slug: str
class CreatorDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
agent_rating: float
agent_runs: int
top_categories: list[str]
# From StoreListingVersion:
listing_version_id: str
listing_version: int
graph_id: str
graph_version: int
class Profile(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str
is_featured: bool = False
class StoreSubmission(pydantic.BaseModel):
listing_id: str
agent_id: str
agent_version: int
name: str
sub_heading: str
slug: str
description: str
instructions: str | None
categories: list[str]
instructions: str | None = None
image_urls: list[str]
video_url: str | None
agent_output_demo_url: str | None
submitted_at: datetime.datetime | None
changes_summary: str | None
date_submitted: datetime.datetime
status: prisma.enums.SubmissionStatus
reviewed_at: datetime.datetime | None = None
runs: int
rating: float
store_listing_version_id: str | None = None
version: int | None = None # Actual version number from the database
reviewer_id: str | None = None
review_comments: str | None = None # External comments visible to creator
internal_comments: str | None = None # Private notes for admin use only
reviewed_at: datetime.datetime | None = None
changes_summary: str | None = None
# Aggregated from AgentGraphExecutions and StoreListingReviews:
run_count: int = 0
review_count: int = 0
review_avg_rating: float = 0.0
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
"""Construct from the StoreSubmission Prisma view."""
return cls(
listing_id=_sub.listing_id,
user_id=_sub.user_id,
slug=_sub.slug,
listing_version_id=_sub.listing_version_id,
listing_version=_sub.listing_version,
graph_id=_sub.graph_id,
graph_version=_sub.graph_version,
name=_sub.name,
sub_heading=_sub.sub_heading,
description=_sub.description,
instructions=_sub.instructions,
categories=_sub.categories,
image_urls=_sub.image_urls,
video_url=_sub.video_url,
agent_output_demo_url=_sub.agent_output_demo_url,
submitted_at=_sub.submitted_at,
changes_summary=_sub.changes_summary,
status=_sub.status,
reviewed_at=_sub.reviewed_at,
reviewer_id=_sub.reviewer_id,
review_comments=_sub.review_comments,
run_count=_sub.run_count,
review_count=_sub.review_count,
review_avg_rating=_sub.review_avg_rating,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
"""
Construct from the StoreListingVersion Prisma model (with StoreListing included)
"""
if not (_l := _lv.StoreListing):
raise ValueError("StoreListingVersion must have included StoreListing")
return cls(
listing_id=_l.id,
user_id=_l.owningUserId,
slug=_l.slug,
listing_version_id=_lv.id,
listing_version=_lv.version,
graph_id=_lv.agentGraphId,
graph_version=_lv.agentGraphVersion,
name=_lv.name,
sub_heading=_lv.subHeading,
description=_lv.description,
instructions=_lv.instructions,
categories=_lv.categories,
image_urls=_lv.imageUrls,
video_url=_lv.videoUrl,
agent_output_demo_url=_lv.agentOutputDemoUrl,
submitted_at=_lv.submittedAt,
changes_summary=_lv.changesSummary,
status=_lv.submissionStatus,
reviewed_at=_lv.reviewedAt,
reviewer_id=_lv.reviewerId,
review_comments=_lv.reviewComments,
)
# Additional fields for editing
video_url: str | None = None
agent_output_demo_url: str | None = None
categories: list[str] = []
class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -273,12 +144,33 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
pagination: Pagination
class StoreListingWithVersions(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
slug: str
agent_id: str
agent_version: int
active_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmission | None = None
versions: list[StoreSubmission] = []
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersions]
pagination: Pagination
class StoreSubmissionRequest(pydantic.BaseModel):
graph_id: str = pydantic.Field(
..., min_length=1, description="Graph ID cannot be empty"
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
)
graph_version: int = pydantic.Field(
..., gt=0, description="Graph version must be greater than 0"
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
)
slug: str
name: str
@@ -306,42 +198,12 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
recommended_schedule_cron: str | None = None
class StoreSubmissionAdminView(StoreSubmission):
internal_comments: str | None # Private admin notes
@classmethod
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
return cls(
**StoreSubmission.from_db(_sub).model_dump(),
internal_comments=_sub.internal_comments,
)
@classmethod
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
return cls(
**StoreSubmission.from_listing_version(_lv).model_dump(),
internal_comments=_lv.internalComments,
)
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
"""A store listing with its version history"""
listing_id: str
graph_id: str
slug: str
active_listing_version_id: str | None = None
has_approved_version: bool = False
creator_email: str | None = None
latest_version: StoreSubmissionAdminView | None = None
versions: list[StoreSubmissionAdminView] = []
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
"""Response model for listings with version history"""
listings: list[StoreListingWithVersionsAdminView]
pagination: Pagination
class ProfileDetails(pydantic.BaseModel):
name: str
username: str
description: str
links: list[str]
avatar_url: str | None = None
class StoreReview(pydantic.BaseModel):
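
The `from_db` classmethods that appear in this file's diff map a database row onto an API model and fill in defaults for missing values. A dependency-free sketch of that pattern follows, assuming plain dataclasses in place of pydantic and Prisma; `StoreAgentRow` and `StoreAgentSketch` are hypothetical names and carry only a subset of the fields.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StoreAgentRow:
    """Hypothetical stand-in for one row of the StoreAgent database view."""

    slug: str
    agent_name: str
    agent_image: list[str] = field(default_factory=list)
    creator_username: Optional[str] = None


@dataclass
class StoreAgentSketch:
    """Hypothetical API-facing model with a subset of the StoreAgent fields."""

    slug: str
    agent_name: str
    agent_image: str
    creator: str

    @classmethod
    def from_db(cls, row: StoreAgentRow) -> "StoreAgentSketch":
        # Take the first image if any exist and fall back to the same
        # placeholder strings the diff uses for missing values.
        return cls(
            slug=row.slug,
            agent_name=row.agent_name,
            agent_image=row.agent_image[0] if row.agent_image else "",
            creator=row.creator_username or "Needs Profile",
        )


agent = StoreAgentSketch.from_db(StoreAgentRow(slug="test-agent", agent_name="Test Agent"))
print(agent)
```

Putting the mapping on the model keeps the defaulting rules in one place instead of repeating them at every call site.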

@@ -0,0 +1,203 @@
import datetime
import prisma.enums
from . import model as store_model
def test_pagination():
pagination = store_model.Pagination(
total_items=100, total_pages=5, current_page=2, page_size=20
)
assert pagination.total_items == 100
assert pagination.total_pages == 5
assert pagination.current_page == 2
assert pagination.page_size == 20
def test_store_agent():
agent = store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
response = store_model.StoreAgentsResponse(
agents=[
store_model.StoreAgent(
slug="test-agent",
agent_name="Test Agent",
agent_image="test.jpg",
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.agents) == 1
assert response.pagination.total_items == 1
def test_store_agent_details():
details = store_model.StoreAgentDetails(
store_listing_version_id="version123",
slug="test-agent",
agent_name="Test Agent",
agent_video="video.mp4",
agent_output_demo="demo.mp4",
agent_image=["image1.jpg", "image2.jpg"],
creator="creator1",
creator_avatar="avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["cat1", "cat2"],
runs=50,
rating=4.5,
versions=["1.0", "2.0"],
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=datetime.datetime.now(),
)
assert details.slug == "test-agent"
assert len(details.agent_image) == 2
assert len(details.categories) == 2
assert len(details.versions) == 2
def test_creator():
creator = store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
assert creator.name == "Test Creator"
assert creator.num_agents == 5
def test_creators_response():
response = store_model.CreatorsResponse(
creators=[
store_model.Creator(
agent_rating=4.8,
agent_runs=1000,
name="Test Creator",
username="creator1",
description="Test description",
avatar_url="avatar.jpg",
num_agents=5,
is_featured=False,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.creators) == 1
assert response.pagination.total_items == 1
def test_creator_details():
details = store_model.CreatorDetails(
name="Test Creator",
username="creator1",
description="Test description",
links=["link1.com", "link2.com"],
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["cat1", "cat2"],
)
assert details.name == "Test Creator"
assert len(details.links) == 2
assert details.agent_rating == 4.8
assert len(details.top_categories) == 2
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg", "image2.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
assert submission.name == "Test Agent"
assert len(submission.image_urls) == 2
assert submission.status == prisma.enums.SubmissionStatus.PENDING
def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
name="Test Agent",
slug="test-agent",
description="Test description",
image_urls=["image1.jpg"],
date_submitted=datetime.datetime(2023, 1, 1),
status=prisma.enums.SubmissionStatus.PENDING,
runs=50,
rating=4.5,
)
],
pagination=store_model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=20
),
)
assert len(response.submissions) == 1
assert response.pagination.total_items == 1
def test_store_submission_request():
request = store_model.StoreSubmissionRequest(
agent_id="agent123",
agent_version=1,
slug="test-agent",
name="Test Agent",
sub_heading="Test subheading",
video_url="video.mp4",
image_urls=["image1.jpg", "image2.jpg"],
description="Test description",
categories=["cat1", "cat2"],
)
assert request.agent_id == "agent123"
assert request.agent_version == 1
assert len(request.image_urls) == 2
assert len(request.categories) == 2

@@ -1,16 +1,16 @@
import logging
import tempfile
import typing
import urllib.parse
from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
from fastapi import Query, Security
from pydantic import BaseModel
import backend.data.graph
import backend.util.json
from backend.util.exceptions import NotFoundError
from backend.util.models import Pagination
from . import cache as store_cache
@@ -34,15 +34,22 @@ router = fastapi.APIRouter()
"/profile",
summary="Get user profile",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.ProfileDetails,
)
async def get_profile(
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Get the profile details for the authenticated user."""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
profile = await store_db.get_user_profile(user_id)
if profile is None:
raise NotFoundError("User does not have a profile yet")
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
return profile
@@ -50,17 +57,98 @@ async def get_profile(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.CreatorDetails,
)
async def update_or_create_profile(
profile: store_model.Profile,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.ProfileDetails:
"""Update the store profile for the authenticated user."""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Update the store profile for the authenticated user.
Args:
profile (Profile): The updated profile details
user_id (str): ID of the authenticated user
Returns:
CreatorDetails: The updated profile
Raises:
HTTPException: If there is an error updating the profile
"""
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
return updated_profile
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=store_model.StoreAgentsResponse,
)
async def get_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
creator (str | None, optional): Filter agents by creator username. Defaults to None.
sorted_by (str | None, optional): Sort agents by "rating", "runs", "name", or "updated_at". Defaults to None.
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
category (str | None, optional): Filter agents by category. Defaults to None.
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of agents per page. Defaults to 20.
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
Raises:
HTTPException: If page or page_size are less than 1
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
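
The same two range checks recur in several handlers in this file. As a rough sketch, they could be isolated in a helper like the one below; `validate_pagination` and `total_pages` are hypothetical names, `ValueError` stands in for the HTTP 422 response, and the ceiling-division page count is an assumption that merely agrees with the `total_items=100, page_size=20, total_pages=5` values used in the model tests.

```python
import math


def validate_pagination(page: int, page_size: int) -> None:
    """Raise ValueError when either value is below 1 (stand-in for the 422 checks)."""
    if page < 1:
        raise ValueError("Page must be greater than 0")
    if page_size < 1:
        raise ValueError("Page size must be greater than 0")


def total_pages(total_items: int, page_size: int) -> int:
    """Pages needed to hold total_items, assuming ceiling division."""
    validate_pagination(1, page_size)
    return math.ceil(total_items / page_size)


validate_pagination(page=2, page_size=20)
print(total_pages(100, 20))
```

The other side of the diff expresses the same constraint declaratively with `Query(ge=1)`, which lets the framework reject out-of-range values before the handler body runs.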
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
##############################################
############### Search Endpoints #############
##############################################
@@ -70,30 +158,60 @@ async def update_or_create_profile(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[prisma.enums.ContentType] | None = Query(
content_types: list[str] | None = fastapi.Query(
default=None,
description="Content types to search. If not specified, searches all.",
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
user_id: str | None = Security(
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
) -> store_model.UnifiedSearchResponse:
):
"""
Search across all content types (marketplace agents, blocks, documentation)
using hybrid search.
Search across all content types (store agents, blocks, documentation) using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
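
The string-to-enum conversion above can be tried in isolation. The sketch below assumes a local `ContentType` enum standing in for `prisma.enums.ContentType`, limited to the three values named in the diff, and raises `ValueError` where the handler raises an HTTP 422; `parse_content_types` is a hypothetical helper name.

```python
from enum import Enum
from typing import Optional


class ContentType(str, Enum):
    """Stand-in for prisma.enums.ContentType (assumed to have these members)."""

    STORE_AGENT = "STORE_AGENT"
    BLOCK = "BLOCK"
    DOCUMENTATION = "DOCUMENTATION"


def parse_content_types(raw: Optional[list[str]]) -> Optional[list[ContentType]]:
    """Convert raw strings to enum members; None or an empty list means 'search all'."""
    if not raw:
        return None
    try:
        return [ContentType(value) for value in raw]
    except ValueError as e:
        valid = ", ".join(member.value for member in ContentType)
        raise ValueError(f"Invalid content type. Valid values: {valid}. Error: {e}") from e


print(parse_content_types(["BLOCK", "DOCUMENTATION"]))
```

Declaring the query parameter as `list[prisma.enums.ContentType]`, as the other side of the diff does, moves this conversion and its error handling into the framework.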
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_types,
content_types=content_type_enums,
user_id=user_id,
page=page,
page_size=page_size,
@@ -127,69 +245,22 @@ async def unified_search(
)
##############################################
############### Agent Endpoints ##############
##############################################
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
)
async def get_agents(
featured: bool = Query(
default=False, description="Filter to only show featured agents"
),
creator: str | None = Query(
default=None, description="Filter agents by creator username"
),
category: str | None = Query(default=None, description="Filter agents by category"),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
default=None,
description="Property to sort results by. Ignored if search_query is provided.",
),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the marketplace,
with optional filtering and sorting.
Used for:
- Home Page Featured Agents
- Home Page Top Agents
- Search Results
- Agent Details - Other Agents By Creator
- Agent Details - Similar Agents
- Creator Details - Agents By Creator
"""
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=store_model.StoreAgentDetails,
)
async def get_agent_by_name(
async def get_agent(
username: str,
agent_name: str,
include_changelog: bool = Query(default=False),
) -> store_model.StoreAgentDetails:
"""Get details of a marketplace agent"""
include_changelog: bool = fastapi.Query(default=False),
):
"""
This is only used on the AgentDetails Page.
It returns the store listing agent's details.
"""
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
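
Both path parameters are percent-decoded and lowercased before the lookup. A small standalone sketch of that normalization follows; `normalize_path_segment` is a hypothetical helper name, not a function from the repository.

```python
import urllib.parse


def normalize_path_segment(segment: str) -> str:
    """Percent-decode a URL path segment, then lowercase it for matching."""
    return urllib.parse.unquote(segment).lower()


print(normalize_path_segment("Some%20Creator"))  # prints "some creator"
print(normalize_path_segment("My-Agent"))  # prints "my-agent"
```

Normalizing at the route boundary means the database lookup only ever sees one canonical spelling of a username or agent slug.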
@@ -199,79 +270,76 @@ async def get_agent_by_name(
return agent
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""
Get Agent Graph from Store Listing Version ID.
"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreReview,
)
async def post_user_review_for_agent(
async def create_review(
username: str,
agent_name: str,
review: store_model.StoreReviewCreate,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreReview:
"""Post a user review on a marketplace agent listing"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a review for a store agent.
Args:
username: Creator's username
agent_name: Name/slug of the agent
review: Review details including score and optional comments
user_id: ID of authenticated user creating the review
Returns:
The created review
"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await store_db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
@router.get(
"/listings/versions/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_agent_by_listing_version(
store_listing_version_id: str,
) -> store_model.StoreAgentDetails:
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
return agent
@router.get(
"/listings/versions/{store_listing_version_id}/graph",
summary="Get agent graph",
tags=["store"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
)
async def get_graph_meta_by_store_listing_version_id(
store_listing_version_id: str,
) -> backend.data.graph.GraphModelWithoutNodes:
"""Get outline of graph belonging to a specific marketplace listing version"""
graph = await store_db.get_available_graph(store_listing_version_id)
return graph
@router.get(
"/listings/versions/{store_listing_version_id}/graph/download",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str,
) -> fastapi.responses.Response:
"""Download agent graph file for a specific marketplace listing version"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
return fastapi.responses.Response(
content=backend.util.json.dumps(graph_data),
media_type="application/json",
headers={
"Content-Disposition": f'attachment; filename="{file_name}"',
},
)
##############################################
############# Creator Endpoints #############
##############################################
@@ -281,19 +349,37 @@ async def download_agent_file(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=store_model.CreatorsResponse,
)
async def get_creators(
featured: bool = Query(
default=False, description="Filter to only show featured creators"
),
search_query: str | None = Query(
default=None, description="Literal + semantic search on names and descriptions"
),
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.CreatorsResponse:
"""List or search marketplace creators"""
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
---
To support this functionality we need:
- featured: bool - to limit the list to just featured creators
- search_query: str - vector search based on the creator's profile description.
- sorted_by: [agent_rating, agent_runs, num_agents] - field to sort the creators by
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
@@ -305,12 +391,18 @@ async def get_creators(
@router.get(
"/creators/{username}",
"/creator/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)
async def get_creator(username: str) -> store_model.CreatorDetails:
"""Get details on a marketplace creator"""
async def get_creator(
username: str,
):
"""
Get the details of a creator.
- Creator Details Page
"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
@@ -322,17 +414,20 @@ async def get_creator(username: str) -> store_model.CreatorDetails:
@router.get(
"/my-unpublished-agents",
"/myagents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.MyAgentsResponse,
)
async def get_my_unpublished_agents(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.MyUnpublishedAgentsResponse:
"""List the authenticated user's unpublished agents"""
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
"""
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
return agents
@@ -341,17 +436,28 @@ async def get_my_unpublished_agents(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=bool,
)
async def delete_submission(
submission_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> bool:
"""Delete a marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Delete a store listing submission.
Args:
user_id (str): ID of the authenticated user
submission_id (str): ID of the submission to be deleted
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
result = await store_db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
@@ -359,14 +465,37 @@ async def delete_submission(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmissionsResponse,
)
async def get_submissions(
user_id: str = Security(autogpt_libs.auth.get_user_id),
page: int = Query(ge=1, default=1),
page_size: int = Query(ge=1, default=20),
) -> store_model.StoreSubmissionsResponse:
"""List the authenticated user's marketplace listing submissions"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Args:
user_id (str): ID of the authenticated user
page (int, optional): Page number for pagination. Defaults to 1.
page_size (int, optional): Number of submissions per page. Defaults to 20.
Returns:
StoreSubmissionsResponse: Paginated list of store submissions
Raises:
HTTPException: If page or page_size are less than 1
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
listings = await store_db.get_store_submissions(
user_id=user_id,
page=page,
@@ -379,17 +508,30 @@ async def get_submissions(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def create_submission(
submission_request: store_model.StoreSubmissionRequest,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Submit a new marketplace listing for review"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Create a new store listing submission.
Args:
submission_request (StoreSubmissionRequest): The submission details
user_id (str): ID of the authenticated user submitting the listing
Returns:
StoreSubmission: The created store submission
Raises:
HTTPException: If there is an error creating the submission
"""
result = await store_db.create_store_submission(
user_id=user_id,
graph_id=submission_request.graph_id,
graph_version=submission_request.graph_version,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
@@ -402,6 +544,7 @@ async def create_submission(
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -409,14 +552,28 @@ async def create_submission(
"/submissions/{store_listing_version_id}",
summary="Edit store submission",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
response_model=store_model.StoreSubmission,
)
async def edit_submission(
store_listing_version_id: str,
submission_request: store_model.StoreSubmissionEditRequest,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmission:
"""Update a pending marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Edit an existing store listing submission.
Args:
store_listing_version_id (str): ID of the store listing version to edit
submission_request (StoreSubmissionEditRequest): The updated submission details
user_id (str): ID of the authenticated user editing the listing
Returns:
StoreSubmission: The updated store submission
Raises:
HTTPException: If there is an error editing the submission
"""
result = await store_db.edit_store_submission(
user_id=user_id,
store_listing_version_id=store_listing_version_id,
@@ -431,6 +588,7 @@ async def edit_submission(
changes_summary=submission_request.changes_summary,
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
@@ -438,61 +596,115 @@ async def edit_submission(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def upload_submission_media(
file: fastapi.UploadFile,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> str:
"""Upload media for a marketplace listing submission"""
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
):
"""
Upload media (images/videos) for a store listing submission.
Args:
file (UploadFile): The media file to upload
user_id (str): ID of the authenticated user uploading the media
Returns:
str: URL of the uploaded media file
Raises:
HTTPException: If there is an error uploading the media
"""
media_url = await store_media.upload_media(user_id=user_id, file=file)
return media_url
class ImageURLResponse(BaseModel):
image_url: str
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[Security(autogpt_libs.auth.requires_user)],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)
async def generate_image(
graph_id: str,
user_id: str = Security(autogpt_libs.auth.get_user_id),
) -> ImageURLResponse:
agent_id: str,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> fastapi.responses.Response:
"""
Generate an image for a marketplace listing submission based on the properties
of a given graph.
Generate an image for a store listing submission.
Args:
agent_id (str): ID of the agent to generate an image for
user_id (str): ID of the authenticated user
Returns:
JSONResponse: JSON containing the URL of the generated image
"""
graph = await backend.data.graph.get_graph(
graph_id=graph_id, version=None, user_id=user_id
agent = await backend.data.graph.get_graph(
graph_id=agent_id, version=None, user_id=user_id
)
if not graph:
raise NotFoundError(f"Agent graph #{graph_id} not found")
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{graph_id}.jpeg"
filename = f"agent_{agent_id}.jpeg"
existing_url = await store_media.check_media_exists(user_id, filename)
if existing_url:
logger.info(f"Using existing image for agent graph {graph_id}")
return ImageURLResponse(image_url=existing_url)
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=graph)
image = await store_image_gen.generate_agent_image(agent=agent)
# Create UploadFile with the correct filename
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return ImageURLResponse(image_url=image_url)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
FileResponse: A file response containing the agent's graph data as JSON.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await store_db.get_agent(store_listing_version_id)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
##############################################
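
The download handler above writes the serialized graph to a named temporary file with `delete=False`, so the path is still valid after the `with` block and can be handed to a file response. A minimal stdlib-only sketch of that pattern follows. The helper name `write_graph_to_temp_file` and the sample payload are hypothetical, and `json.dumps` stands in for `backend.util.json.dumps`.

```python
import json
import os
import tempfile


def write_graph_to_temp_file(graph_data: dict) -> str:
    """Serialize graph_data to a temporary .json file and return its path.

    Hypothetical helper for illustration; not part of the codebase in the diff.
    """
    # delete=False keeps the file on disk after the context manager exits,
    # so the caller becomes responsible for removing it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as tmp_file:
        tmp_file.write(json.dumps(graph_data))
        tmp_file.flush()
    return tmp_file.name


if __name__ == "__main__":
    path = write_graph_to_temp_file({"id": "example-graph", "version": 1})
    try:
        with open(path) as f:
            print(json.load(f))
    finally:
        # Explicit cleanup, since delete=False disables automatic removal.
        os.remove(path)
```

Because nothing deletes the file automatically, a handler built this way would likely need its own cleanup step once the response has been sent.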

View File

@@ -8,8 +8,6 @@ import pytest
import pytest_mock
from pytest_snapshot.plugin import Snapshot
from backend.api.features.store.db import StoreAgentsSortOptions
from . import model as store_model
from . import routes as store_routes
@@ -198,7 +196,7 @@ def test_get_agents_sorted(
mock_db_call.assert_called_once_with(
featured=False,
creators=None,
sorted_by=StoreAgentsSortOptions.RUNS,
sorted_by="runs",
search_query=None,
category=None,
page=1,
@@ -382,11 +380,9 @@ def test_get_agent_details(
runs=100,
rating=4.5,
versions=["1.0.0", "1.1.0"],
graph_versions=["1", "2"],
graph_id="test-graph-id",
agentGraphVersions=["1", "2"],
agentGraphId="test-graph-id",
last_updated=FIXED_NOW,
active_version_id="test-version-id",
has_approved_version=True,
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
mock_db_call.return_value = mocked_value
@@ -439,17 +435,15 @@ def test_get_creators_pagination(
) -> None:
mocked_value = store_model.CreatorsResponse(
creators=[
store_model.CreatorDetails(
store_model.Creator(
name=f"Creator {i}",
username=f"creator{i}",
avatar_url=f"avatar{i}.jpg",
description=f"Creator {i} description",
links=[f"user{i}.link.com"],
is_featured=False,
avatar_url=f"avatar{i}.jpg",
num_agents=1,
agent_runs=100,
agent_rating=4.5,
top_categories=["cat1", "cat2", "cat3"],
agent_runs=100,
is_featured=False,
)
for i in range(5)
],
@@ -502,19 +496,19 @@ def test_get_creator_details(
mocked_value = store_model.CreatorDetails(
name="Test User",
username="creator1",
avatar_url="avatar.jpg",
description="Test creator description",
links=["link1.com", "link2.com"],
is_featured=True,
num_agents=5,
agent_runs=1000,
avatar_url="avatar.jpg",
agent_rating=4.8,
agent_runs=1000,
top_categories=["category1", "category2"],
)
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
mock_db_call = mocker.patch(
"backend.api.features.store.db.get_store_creator_details"
)
mock_db_call.return_value = mocked_value
response = client.get("/creators/creator1")
response = client.get("/creator/creator1")
assert response.status_code == 200
data = store_model.CreatorDetails.model_validate(response.json())
@@ -534,26 +528,19 @@ def test_get_submissions_success(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
user_id="test-user-id",
slug="test-agent",
listing_version_id="test-version-id",
listing_version=1,
graph_id="test-agent-id",
graph_version=1,
name="Test Agent",
sub_heading="Test agent subheading",
description="Test agent description",
instructions="Click the button!",
categories=["test-category"],
image_urls=["test.jpg"],
video_url="test.mp4",
agent_output_demo_url="demo_video.mp4",
submitted_at=FIXED_NOW,
changes_summary="Initial Submission",
date_submitted=FIXED_NOW,
status=prisma.enums.SubmissionStatus.APPROVED,
run_count=50,
review_count=5,
review_avg_rating=4.2,
runs=50,
rating=4.2,
agent_id="test-agent-id",
agent_version=1,
sub_heading="Test agent subheading",
slug="test-agent",
video_url="test.mp4",
categories=["test-category"],
)
],
pagination=store_model.Pagination(

View File

@@ -11,7 +11,6 @@ import pytest
from backend.util.models import Pagination
from . import cache as store_cache
from .db import StoreAgentsSortOptions
from .model import StoreAgent, StoreAgentsResponse
@@ -216,7 +215,7 @@ class TestCacheDeletion:
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,
@@ -228,7 +227,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,
@@ -240,7 +239,7 @@ class TestCacheDeletion:
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by=StoreAgentsSortOptions.RATING,
sorted_by="rating",
search_query="AI assistant",
category="productivity",
page=2,
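
The cache tests above call `cache_delete` with exactly the same keyword arguments that populated the entry, which suggests the cache key is derived from the call arguments. The sketch below illustrates that idea only; `cached_by_kwargs` is a hypothetical decorator, not the cache implementation used by `store_cache`.

```python
import asyncio
from typing import Any, Awaitable, Callable

_MISSING = object()


def cached_by_kwargs(func: Callable[..., Awaitable[Any]]):
    """Hypothetical async cache keyed on keyword arguments (illustration only)."""
    store: dict[tuple, Any] = {}

    def make_key(kwargs: dict) -> tuple:
        # Sort by argument name so the key does not depend on call order.
        return tuple(sorted(kwargs.items()))

    async def wrapper(**kwargs):
        key = make_key(kwargs)
        if key not in store:
            store[key] = await func(**kwargs)
        return store[key]

    def cache_delete(**kwargs) -> bool:
        # True only if an entry with exactly these arguments was cached.
        return store.pop(make_key(kwargs), _MISSING) is not _MISSING

    wrapper.cache_delete = cache_delete
    return wrapper


@cached_by_kwargs
async def get_agents(**kwargs):
    return {"query": kwargs}


async def demo() -> tuple[bool, bool, bool]:
    await get_agents(sorted_by="rating", page=2)
    wrong_args = get_agents.cache_delete(sorted_by="runs", page=2)
    first = get_agents.cache_delete(sorted_by="rating", page=2)
    second = get_agents.cache_delete(sorted_by="rating", page=2)
    return wrong_args, first, second


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Under this assumption, changing the value passed for `sorted_by` (for example from an enum member to the plain string `"rating"`) changes the key, so the population call and the deletion call have to agree on it.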

View File

@@ -1,5 +0,0 @@
"""Backward-compatibility shim — ``split_camelcase`` now lives in backend.util.text."""
from backend.util.text import split_camelcase # noqa: F401
__all__ = ["split_camelcase"]

View File

@@ -1,49 +0,0 @@
"""Tests for split_camelcase (now in backend.util.text)."""
import pytest
from backend.util.text import split_camelcase
# ---------------------------------------------------------------------------
# split_camelcase
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"input_text, expected",
[
("AITextGeneratorBlock", "AI Text Generator Block"),
("HTTPRequestBlock", "HTTP Request Block"),
("simpleWord", "simple Word"),
("already spaced", "already spaced"),
("XMLParser", "XML Parser"),
("getHTTPResponse", "get HTTP Response"),
("Block", "Block"),
("", ""),
("OAuth2Block", "OAuth2 Block"),
("IOError", "IO Error"),
("getHTTPSResponse", "get HTTPS Response"),
# Known limitation: single-letter uppercase prefixes are NOT split.
# "ABlock" stays "ABlock" because the algorithm requires the left
# part of an uppercase run to retain at least 2 uppercase chars.
("ABlock", "ABlock"),
# Digit-to-uppercase transitions
("Base64Encoder", "Base64 Encoder"),
("UTF8Decoder", "UTF8 Decoder"),
# Pure digits — no camelCase boundaries to split
("123", "123"),
# Known limitation: single-letter uppercase segments after digits
# are not split from the following word. "3D" is only 1 uppercase
# char so the uppercase-run rule cannot fire, producing "3 DRenderer"
# rather than the ideal "3D Renderer".
("3DRenderer", "3 DRenderer"),
# Exception list — compound terms that should stay together
("YouTubeBlock", "YouTube Block"),
("OpenAIBlock", "OpenAI Block"),
("AutoGPTAgent", "AutoGPT Agent"),
("GitHubIntegration", "GitHub Integration"),
("LinkedInBlock", "LinkedIn Block"),
],
)
def test_split_camelcase(input_text: str, expected: str):
assert split_camelcase(input_text) == expected
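
The parametrized cases above pin down the splitting rules fairly precisely: split on lowercase-or-digit to uppercase transitions, split an uppercase run before its last capital only when at least two capitals stay on the left, and never split inside a listed compound term. The following is a minimal sketch that satisfies those cases. It is not the implementation in `backend.util.text`; the name `split_camelcase_sketch`, the regular expression, and the exception tuple are assumptions made for illustration.

```python
import re

# Boundary 1: lowercase letter or digit followed by an uppercase letter.
# Boundary 2: two or more uppercase letters followed by Uppercase+lowercase,
#             so "HTTPRequest" splits as "HTTP Request" but "ABlock" does not.
_BOUNDARY = re.compile(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z]{2})(?=[A-Z][a-z])")

# Assumed exception list, taken from the compound terms in the test cases.
_EXCEPTIONS = ("YouTube", "OpenAI", "AutoGPT", "GitHub", "LinkedIn")
_EXCEPTION_RE = re.compile("|".join(re.escape(term) for term in _EXCEPTIONS))


def split_camelcase_sketch(text: str) -> str:
    """Insert spaces at camelCase boundaries, keeping exception terms whole."""
    # Positions strictly inside an exception term must not be split.
    protected: set[int] = set()
    for match in _EXCEPTION_RE.finditer(text):
        protected.update(range(match.start() + 1, match.end()))

    cuts = [
        match.start()
        for match in _BOUNDARY.finditer(text)
        if match.start() not in protected
    ]

    parts = []
    previous = 0
    for cut in cuts:
        parts.append(text[previous:cut])
        previous = cut
    parts.append(text[previous:])
    return " ".join(parts)


if __name__ == "__main__":
    for sample in ("HTTPRequestBlock", "ABlock", "3DRenderer", "OpenAIBlock"):
        print(repr(sample), "->", repr(split_camelcase_sketch(sample)))
```

The two documented limitations fall out of the second boundary rule: `ABlock` and the `DR` in `3DRenderer` each have only one capital to the left of the candidate split, so the two-capital lookbehind does not match.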

View File

@@ -449,6 +449,7 @@ async def execute_graph_block(
async def upload_file(
user_id: Annotated[str, Security(get_user_id)],
file: UploadFile = File(...),
provider: str = "gcs",
expiration_hours: int = 24,
) -> UploadFileResponse:
"""
@@ -511,6 +512,7 @@ async def upload_file(
storage_path = await cloud_storage.store_file(
content=content,
filename=file_name,
provider=provider,
expiration_hours=expiration_hours,
user_id=user_id,
)

View File

@@ -515,6 +515,7 @@ async def test_upload_file_success(test_user_id: str):
result = await upload_file(
file=upload_file_mock,
user_id=test_user_id,
provider="gcs",
expiration_hours=24,
)
@@ -532,6 +533,7 @@ async def test_upload_file_success(test_user_id: str):
mock_handler.store_file.assert_called_once_with(
content=file_content,
filename="test.txt",
provider="gcs",
expiration_hours=24,
user_id=test_user_id,
)

View File

@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
class CopilotCompletionPayload(NotificationPayload):
session_id: str
status: Literal["completed", "failed"]

View File

@@ -55,7 +55,6 @@ from backend.util.exceptions import (
MissingConfigError,
NotAuthorizedError,
NotFoundError,
PreconditionFailed,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
@@ -276,7 +275,6 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")

View File

@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
@@ -624,7 +620,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -653,7 +648,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -1,33 +0,0 @@
"""
Shared configuration for all AgentMail blocks.
"""
from agentmail import AsyncAgentMail
from backend.sdk import APIKeyCredentials, ProviderBuilder, SecretStr
agent_mail = (
ProviderBuilder("agent_mail")
.with_api_key("AGENTMAIL_API_KEY", "AgentMail API Key")
.build()
)
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="agent_mail",
title="Mock AgentMail API Key",
api_key=SecretStr("mock-agentmail-api-key"),
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"id": TEST_CREDENTIALS.id,
"provider": TEST_CREDENTIALS.provider,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def _client(credentials: APIKeyCredentials) -> AsyncAgentMail:
"""Create an AsyncAgentMail client from credentials."""
return AsyncAgentMail(api_key=credentials.api_key.get_secret_value())

View File

@@ -1,211 +0,0 @@
"""
AgentMail Attachment blocks — download file attachments from messages and threads.
Attachments are files associated with messages (PDFs, CSVs, images, etc.).
To send attachments, include them in the attachments parameter when using
AgentMailSendMessageBlock or AgentMailReplyToMessageBlock.
To download, first get the attachment_id from a message's attachments array,
then use these blocks to retrieve the file content as base64.
"""
import base64
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailGetMessageAttachmentBlock(Block):
"""
Download a file attachment from a specific email message.
Retrieves the raw file content and returns it as base64-encoded data.
First get the attachment_id from a message object's attachments array,
then use this block to download the file.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the message belongs to"
)
message_id: str = SchemaField(
description="Message ID containing the attachment"
)
attachment_id: str = SchemaField(
description="Attachment ID to download (from the message's attachments array)"
)
class Output(BlockSchemaOutput):
content_base64: str = SchemaField(
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
)
attachment_id: str = SchemaField(
description="The attachment ID that was downloaded"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="a283ffc4-8087-4c3d-9135-8f26b86742ec",
description="Download a file attachment from an email message. Returns base64-encoded file content.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"message_id": "test-msg",
"attachment_id": "test-attach",
},
test_output=[
("content_base64", "dGVzdA=="),
("attachment_id", "test-attach"),
],
test_mock={
"get_attachment": lambda *a, **kw: b"test",
},
)
@staticmethod
async def get_attachment(
credentials: APIKeyCredentials,
inbox_id: str,
message_id: str,
attachment_id: str,
):
client = _client(credentials)
return await client.inboxes.messages.get_attachment(
inbox_id=inbox_id,
message_id=message_id,
attachment_id=attachment_id,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
data = await self.get_attachment(
credentials=credentials,
inbox_id=input_data.inbox_id,
message_id=input_data.message_id,
attachment_id=input_data.attachment_id,
)
if isinstance(data, bytes):
encoded = base64.b64encode(data).decode()
elif isinstance(data, str):
encoded = base64.b64encode(data.encode("utf-8")).decode()
else:
raise TypeError(
f"Unexpected attachment data type: {type(data).__name__}"
)
yield "content_base64", encoded
yield "attachment_id", input_data.attachment_id
except Exception as e:
yield "error", str(e)
class AgentMailGetThreadAttachmentBlock(Block):
"""
Download a file attachment from a conversation thread.
Same as GetMessageAttachment but looks up by thread ID instead of
message ID. Useful when you know the thread but not the specific
message containing the attachment.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the thread belongs to"
)
thread_id: str = SchemaField(description="Thread ID containing the attachment")
attachment_id: str = SchemaField(
description="Attachment ID to download (from a message's attachments array within the thread)"
)
class Output(BlockSchemaOutput):
content_base64: str = SchemaField(
description="File content encoded as a base64 string. Decode with base64.b64decode() to get raw bytes."
)
attachment_id: str = SchemaField(
description="The attachment ID that was downloaded"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="06b6a4c4-9d71-4992-9e9c-cf3b352763b5",
description="Download a file attachment from a conversation thread. Returns base64-encoded file content.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"thread_id": "test-thread",
"attachment_id": "test-attach",
},
test_output=[
("content_base64", "dGVzdA=="),
("attachment_id", "test-attach"),
],
test_mock={
"get_attachment": lambda *a, **kw: b"test",
},
)
@staticmethod
async def get_attachment(
credentials: APIKeyCredentials,
inbox_id: str,
thread_id: str,
attachment_id: str,
):
client = _client(credentials)
return await client.inboxes.threads.get_attachment(
inbox_id=inbox_id,
thread_id=thread_id,
attachment_id=attachment_id,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
data = await self.get_attachment(
credentials=credentials,
inbox_id=input_data.inbox_id,
thread_id=input_data.thread_id,
attachment_id=input_data.attachment_id,
)
if isinstance(data, bytes):
encoded = base64.b64encode(data).decode()
elif isinstance(data, str):
encoded = base64.b64encode(data.encode("utf-8")).decode()
else:
raise TypeError(
f"Unexpected attachment data type: {type(data).__name__}"
)
yield "content_base64", encoded
yield "attachment_id", input_data.attachment_id
except Exception as e:
yield "error", str(e)
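
Both attachment blocks above repeat the same normalization step: accept `bytes` or `str`, base64-encode it, and reject anything else. As a sketch of how that step could be factored out, the helper below reproduces the logic with the standard library only. The name `encode_attachment` is hypothetical and does not appear in the diff.

```python
import base64


def encode_attachment(data: object) -> str:
    """Return data as a base64 string, mirroring the checks in the blocks above.

    Hypothetical helper for illustration only.
    """
    if isinstance(data, bytes):
        return base64.b64encode(data).decode()
    if isinstance(data, str):
        # Text payloads are encoded as UTF-8 before base64 encoding.
        return base64.b64encode(data.encode("utf-8")).decode()
    raise TypeError(f"Unexpected attachment data type: {type(data).__name__}")


if __name__ == "__main__":
    encoded = encode_attachment(b"test")
    print(encoded)
    # Consumers recover the raw bytes with base64.b64decode().
    print(base64.b64decode(encoded))
```

This matches the blocks' own test fixtures, where the mocked payload `b"test"` is expected to yield `dGVzdA==`.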

View File

@@ -1,678 +0,0 @@
"""
AgentMail Draft blocks — create, get, list, update, send, and delete drafts.
A Draft is an unsent message that can be reviewed, edited, and sent later.
Drafts enable human-in-the-loop review, scheduled sending (via send_at),
and complex multi-step email composition workflows.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailCreateDraftBlock(Block):
"""
Create a draft email in an AgentMail inbox for review or scheduled sending.
Drafts let agents prepare emails without sending immediately. Use send_at
to schedule automatic sending at a future time (ISO 8601 format).
Scheduled drafts are auto-labeled 'scheduled' and can be cancelled by
deleting the draft.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to create the draft in"
)
to: list[str] = SchemaField(
description="Recipient email addresses (e.g. ['user@example.com'])"
)
subject: str = SchemaField(description="Email subject line", default="")
text: str = SchemaField(description="Plain text body of the draft", default="")
html: str = SchemaField(
description="Rich HTML body of the draft", default="", advanced=True
)
cc: list[str] = SchemaField(
description="CC recipient email addresses",
default_factory=list,
advanced=True,
)
bcc: list[str] = SchemaField(
description="BCC recipient email addresses",
default_factory=list,
advanced=True,
)
in_reply_to: str = SchemaField(
description="Message ID this draft replies to, for threading follow-up drafts",
default="",
advanced=True,
)
send_at: str = SchemaField(
description="Schedule automatic sending at this ISO 8601 datetime (e.g. '2025-01-15T09:00:00Z'). Leave empty for manual send.",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
draft_id: str = SchemaField(
description="Unique identifier of the created draft"
)
send_status: str = SchemaField(
description="'scheduled' if send_at was set, empty otherwise. Values: scheduled, sending, failed.",
default="",
)
result: dict = SchemaField(
description="Complete draft object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="25ac9086-69fd-48b8-b910-9dbe04b8f3bd",
description="Create a draft email for review or scheduled sending. Use send_at for automatic future delivery.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"to": ["user@example.com"],
},
test_output=[
("draft_id", "mock-draft-id"),
("send_status", ""),
("result", dict),
],
test_mock={
"create_draft": lambda *a, **kw: type(
"Draft",
(),
{
"draft_id": "mock-draft-id",
"send_status": "",
"model_dump": lambda self: {"draft_id": "mock-draft-id"},
},
)(),
},
)
@staticmethod
async def create_draft(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.drafts.create(inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"to": input_data.to}
if input_data.subject:
params["subject"] = input_data.subject
if input_data.text:
params["text"] = input_data.text
if input_data.html:
params["html"] = input_data.html
if input_data.cc:
params["cc"] = input_data.cc
if input_data.bcc:
params["bcc"] = input_data.bcc
if input_data.in_reply_to:
params["in_reply_to"] = input_data.in_reply_to
if input_data.send_at:
params["send_at"] = input_data.send_at
draft = await self.create_draft(credentials, input_data.inbox_id, **params)
result = draft.model_dump()
yield "draft_id", draft.draft_id
yield "send_status", draft.send_status or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailGetDraftBlock(Block):
"""
Retrieve a specific draft from an AgentMail inbox.
Returns the draft contents including recipients, subject, body, and
scheduled send status. Use this to review a draft before approving it.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the draft belongs to"
)
draft_id: str = SchemaField(description="Draft ID to retrieve")
class Output(BlockSchemaOutput):
draft_id: str = SchemaField(description="Unique identifier of the draft")
subject: str = SchemaField(description="Draft subject line", default="")
send_status: str = SchemaField(
description="Scheduled send status: 'scheduled', 'sending', 'failed', or empty",
default="",
)
send_at: str = SchemaField(
description="Scheduled send time (ISO 8601) if set", default=""
)
result: dict = SchemaField(description="Complete draft object with all fields")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="8e57780d-dc25-43d4-a0f4-1f02877b09fb",
description="Retrieve a draft email to review its contents, recipients, and scheduled send status.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"draft_id": "test-draft",
},
test_output=[
("draft_id", "test-draft"),
("subject", ""),
("send_status", ""),
("send_at", ""),
("result", dict),
],
test_mock={
"get_draft": lambda *a, **kw: type(
"Draft",
(),
{
"draft_id": "test-draft",
"subject": "",
"send_status": "",
"send_at": "",
"model_dump": lambda self: {"draft_id": "test-draft"},
},
)(),
},
)
@staticmethod
async def get_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
client = _client(credentials)
return await client.inboxes.drafts.get(inbox_id=inbox_id, draft_id=draft_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
draft = await self.get_draft(
credentials, input_data.inbox_id, input_data.draft_id
)
result = draft.model_dump()
yield "draft_id", draft.draft_id
yield "subject", draft.subject or ""
yield "send_status", draft.send_status or ""
yield "send_at", draft.send_at or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailListDraftsBlock(Block):
"""
List all drafts in an AgentMail inbox with optional label filtering.
Use labels=['scheduled'] to find all drafts queued for future sending.
Useful for building approval dashboards or monitoring pending outreach.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to list drafts from"
)
limit: int = SchemaField(
description="Maximum number of drafts to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
labels: list[str] = SchemaField(
description="Filter drafts by labels (e.g. ['scheduled'] for pending sends)",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
drafts: list[dict] = SchemaField(
description="List of draft objects with subject, recipients, send_status, etc."
)
count: int = SchemaField(description="Number of drafts returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="e84883b7-7c39-4c5c-88e8-0a72b078ea63",
description="List drafts in an AgentMail inbox. Filter by labels=['scheduled'] to find pending sends.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
},
test_output=[
("drafts", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_drafts": lambda *a, **kw: type(
"Resp",
(),
{
"drafts": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_drafts(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.drafts.list(inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
if input_data.labels:
params["labels"] = input_data.labels
response = await self.list_drafts(
credentials, input_data.inbox_id, **params
)
drafts = [d.model_dump() for d in response.drafts]
yield "drafts", drafts
yield "count", response.count
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailUpdateDraftBlock(Block):
"""
Update an existing draft's content, recipients, or scheduled send time.
Use this to reschedule a draft (change send_at), modify recipients,
or edit the subject/body before sending. To cancel a scheduled send,
delete the draft instead.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the draft belongs to"
)
draft_id: str = SchemaField(description="Draft ID to update")
to: Optional[list[str]] = SchemaField(
description="Updated recipient email addresses (replaces existing list). Omit to keep current value.",
default=None,
)
subject: Optional[str] = SchemaField(
description="Updated subject line. Omit to keep current value.",
default=None,
)
text: Optional[str] = SchemaField(
description="Updated plain text body. Omit to keep current value.",
default=None,
)
html: Optional[str] = SchemaField(
description="Updated HTML body. Omit to keep current value.",
default=None,
advanced=True,
)
send_at: Optional[str] = SchemaField(
description="Reschedule: new ISO 8601 send time (e.g. '2025-01-20T14:00:00Z'). Omit to keep current value.",
default=None,
advanced=True,
)
class Output(BlockSchemaOutput):
draft_id: str = SchemaField(description="The updated draft ID")
send_status: str = SchemaField(description="Updated send status", default="")
result: dict = SchemaField(description="Complete updated draft object")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="351f6e51-695a-421a-9032-46a587b10336",
description="Update a draft's content, recipients, or scheduled send time. Use to reschedule or edit before sending.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"draft_id": "test-draft",
},
test_output=[
("draft_id", "test-draft"),
("send_status", ""),
("result", dict),
],
test_mock={
"update_draft": lambda *a, **kw: type(
"Draft",
(),
{
"draft_id": "test-draft",
"send_status": "",
"model_dump": lambda self: {"draft_id": "test-draft"},
},
)(),
},
)
@staticmethod
async def update_draft(
credentials: APIKeyCredentials, inbox_id: str, draft_id: str, **params
):
client = _client(credentials)
return await client.inboxes.drafts.update(
inbox_id=inbox_id, draft_id=draft_id, **params
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {}
if input_data.to is not None:
params["to"] = input_data.to
if input_data.subject is not None:
params["subject"] = input_data.subject
if input_data.text is not None:
params["text"] = input_data.text
if input_data.html is not None:
params["html"] = input_data.html
if input_data.send_at is not None:
params["send_at"] = input_data.send_at
draft = await self.update_draft(
credentials, input_data.inbox_id, input_data.draft_id, **params
)
result = draft.model_dump()
yield "draft_id", draft.draft_id
yield "send_status", draft.send_status or ""
yield "result", result
except Exception as e:
yield "error", str(e)
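The update-draft `run` method above forwards a field only when it `is not None`, so an explicitly empty value such as `subject=""` or `to=[]` is still sent, whereas the create-inbox block in this diff uses a plain truthiness check that drops empty strings. The sketch below contrasts the two filters. `build_params` is a hypothetical helper written for illustration and is not part of the block code.

```python
from typing import Any, Optional


def build_params(*, skip_falsy: bool, **fields: Optional[Any]) -> dict[str, Any]:
    """Hypothetical helper: keep only the fields that should be sent."""
    if skip_falsy:
        # Truthiness check: None, "" and [] are all dropped.
        return {k: v for k, v in fields.items() if v}
    # Identity check: only None means "keep the current value".
    return {k: v for k, v in fields.items() if v is not None}


# Update-style filtering: empty values are forwarded, None is omitted.
update = build_params(skip_falsy=False, subject="", text=None, to=[])
# Create-style filtering: anything falsy is omitted.
create = build_params(skip_falsy=True, username="", domain=None, display_name="Bot")
print(update)
print(create)
```

The identity check suits updates, where clearing a field and leaving it untouched are different requests; the truthiness check suits creation, where an empty string simply means "use the default".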
class AgentMailSendDraftBlock(Block):
"""
Send a draft immediately, converting it into a delivered message.
The draft is deleted after successful sending and becomes a regular
message with a message_id. Use this for human-in-the-loop approval
workflows: agent creates draft, human reviews, then this block sends it.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the draft belongs to"
)
draft_id: str = SchemaField(description="Draft ID to send now")
class Output(BlockSchemaOutput):
message_id: str = SchemaField(
description="Message ID of the now-sent email (draft is deleted)"
)
thread_id: str = SchemaField(
description="Thread ID the sent message belongs to"
)
result: dict = SchemaField(description="Complete sent message object")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="37c39e83-475d-4b3d-843a-d923d001b85a",
description="Send a draft immediately, converting it into a delivered message. The draft is deleted after sending.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"draft_id": "test-draft",
},
test_output=[
("message_id", "mock-msg-id"),
("thread_id", "mock-thread-id"),
("result", dict),
],
test_mock={
"send_draft": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "mock-msg-id",
"thread_id": "mock-thread-id",
"model_dump": lambda self: {"message_id": "mock-msg-id"},
},
)(),
},
)
@staticmethod
async def send_draft(credentials: APIKeyCredentials, inbox_id: str, draft_id: str):
client = _client(credentials)
return await client.inboxes.drafts.send(inbox_id=inbox_id, draft_id=draft_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
msg = await self.send_draft(
credentials, input_data.inbox_id, input_data.draft_id
)
result = msg.model_dump()
yield "message_id", msg.message_id
yield "thread_id", msg.thread_id or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailDeleteDraftBlock(Block):
"""
Delete a draft from an AgentMail inbox.
If the draft was scheduled with send_at, deleting it also cancels the
scheduled delivery. This is how to cancel a scheduled email.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the draft belongs to"
)
draft_id: str = SchemaField(
description="Draft ID to delete (also cancels scheduled sends)"
)
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="True if the draft was successfully deleted/cancelled"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="9023eb99-3e2f-4def-808b-d9c584b3d9e7",
description="Delete a draft or cancel a scheduled email. Removes the draft permanently.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"draft_id": "test-draft",
},
test_output=[("success", True)],
test_mock={
"delete_draft": lambda *a, **kw: None,
},
)
@staticmethod
async def delete_draft(
credentials: APIKeyCredentials, inbox_id: str, draft_id: str
):
client = _client(credentials)
await client.inboxes.drafts.delete(inbox_id=inbox_id, draft_id=draft_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
await self.delete_draft(
credentials, input_data.inbox_id, input_data.draft_id
)
yield "success", True
except Exception as e:
yield "error", str(e)
class AgentMailListOrgDraftsBlock(Block):
"""
List all drafts across every inbox in your organization.
Returns drafts from all inboxes in one query. Perfect for building
a central approval dashboard where a human supervisor can review
and approve any draft created by any agent.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
limit: int = SchemaField(
description="Maximum number of drafts to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
drafts: list[dict] = SchemaField(
description="List of draft objects from all inboxes in the organization"
)
count: int = SchemaField(description="Number of drafts returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="ed7558ae-3a07-45f5-af55-a25fe88c9971",
description="List all drafts across every inbox in your organization. Use for central approval dashboards.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("drafts", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_org_drafts": lambda *a, **kw: type(
"Resp",
(),
{
"drafts": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_org_drafts(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.drafts.list(**params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_org_drafts(credentials, **params)
drafts = [d.model_dump() for d in response.drafts]
yield "drafts", drafts
yield "count", (c if (c := response.count) is not None else len(drafts))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
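The list blocks above return one page at a time together with a `next_page_token`, which is empty when no further results exist. A minimal sketch of how a caller might walk every page follows. `fake_list_drafts` and its offset-style token are assumptions made so the example runs without network access; the real AgentMail client and its token format are not modelled here.

```python
import asyncio
from types import SimpleNamespace

# Hypothetical in-memory data standing in for drafts held by the API.
_FAKE_DRAFTS = [{"draft_id": f"draft-{i}"} for i in range(5)]


async def fake_list_drafts(limit: int, page_token: str = "") -> SimpleNamespace:
    """Stand-in for a paginated list call; the token is just an offset here."""
    start = int(page_token) if page_token else 0
    page = _FAKE_DRAFTS[start : start + limit]
    end = start + len(page)
    next_token = str(end) if end < len(_FAKE_DRAFTS) else ""
    return SimpleNamespace(drafts=page, count=len(page), next_page_token=next_token)


async def collect_all_drafts(limit: int = 2) -> list[dict]:
    """Follow next_page_token until the response reports no further page."""
    drafts: list[dict] = []
    page_token = ""
    while True:
        response = await fake_list_drafts(limit, page_token)
        drafts.extend(response.drafts)
        # Mirror the blocks above: treat a missing token as an empty string.
        page_token = response.next_page_token or ""
        if not page_token:
            return drafts


all_drafts = asyncio.run(collect_all_drafts())
print(len(all_drafts))
```

In a graph of blocks the same loop would be expressed by feeding the `next_page_token` output back into the `page_token` input until it comes back empty.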


@@ -1,414 +0,0 @@
"""
AgentMail Inbox blocks — create, get, list, update, and delete inboxes.
An Inbox is a fully programmable email account for AI agents. Each inbox gets
a unique email address and can send, receive, and manage emails via the
AgentMail API. You can create thousands of inboxes on demand.
"""
from agentmail.inboxes.types import CreateInboxRequest
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailCreateInboxBlock(Block):
"""
Create a new email inbox for an AI agent via AgentMail.
Each inbox gets a unique email address (e.g. username@agentmail.to).
If username and domain are not provided, AgentMail auto-generates them.
Use custom domains by specifying the domain field.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
username: str = SchemaField(
description="Local part of the email address (e.g. 'support' for support@domain.com). Leave empty to auto-generate.",
default="",
advanced=False,
)
domain: str = SchemaField(
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
default="",
advanced=False,
)
display_name: str = SchemaField(
description="Friendly name shown in the 'From' field of sent emails (e.g. 'Support Agent')",
default="",
advanced=False,
)
class Output(BlockSchemaOutput):
inbox_id: str = SchemaField(
description="Unique identifier for the created inbox (also the email address)"
)
email_address: str = SchemaField(
description="Full email address of the inbox (e.g. support@agentmail.to)"
)
result: dict = SchemaField(
description="Complete inbox object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="7a8ac219-c6ec-4eec-a828-81af283ce04c",
description="Create a new email inbox for an AI agent via AgentMail. Each inbox gets a unique address and can send/receive emails.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("inbox_id", "mock-inbox-id"),
("email_address", "mock-inbox-id"),
("result", dict),
],
test_mock={
"create_inbox": lambda *a, **kw: type(
"Inbox",
(),
{
"inbox_id": "mock-inbox-id",
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
},
)(),
},
)
@staticmethod
async def create_inbox(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.inboxes.create(request=CreateInboxRequest(**params))
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {}
if input_data.username:
params["username"] = input_data.username
if input_data.domain:
params["domain"] = input_data.domain
if input_data.display_name:
params["display_name"] = input_data.display_name
inbox = await self.create_inbox(credentials, **params)
result = inbox.model_dump()
yield "inbox_id", inbox.inbox_id
yield "email_address", inbox.inbox_id
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailGetInboxBlock(Block):
"""
Retrieve details of an existing AgentMail inbox by its ID or email address.
Returns the inbox metadata including email address, display name, and
configuration. Use this to check if an inbox exists or get its properties.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to look up (e.g. 'support@agentmail.to')"
)
class Output(BlockSchemaOutput):
inbox_id: str = SchemaField(description="Unique identifier of the inbox")
email_address: str = SchemaField(description="Full email address of the inbox")
display_name: str = SchemaField(
description="Friendly name shown in the 'From' field", default=""
)
result: dict = SchemaField(
description="Complete inbox object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b858f62b-6c12-4736-aaf2-dbc5a9281320",
description="Retrieve details of an existing AgentMail inbox including its email address, display name, and configuration.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
},
test_output=[
("inbox_id", "test-inbox"),
("email_address", "test-inbox"),
("display_name", ""),
("result", dict),
],
test_mock={
"get_inbox": lambda *a, **kw: type(
"Inbox",
(),
{
"inbox_id": "test-inbox",
"display_name": "",
"model_dump": lambda self: {"inbox_id": "test-inbox"},
},
)(),
},
)
@staticmethod
async def get_inbox(credentials: APIKeyCredentials, inbox_id: str):
client = _client(credentials)
return await client.inboxes.get(inbox_id=inbox_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
inbox = await self.get_inbox(credentials, input_data.inbox_id)
result = inbox.model_dump()
yield "inbox_id", inbox.inbox_id
yield "email_address", inbox.inbox_id
yield "display_name", inbox.display_name or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailListInboxesBlock(Block):
"""
List all email inboxes in your AgentMail organization.
Returns a paginated list of all inboxes with their metadata.
Use page_token for pagination when you have many inboxes.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
limit: int = SchemaField(
description="Maximum number of inboxes to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page of results",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
inboxes: list[dict] = SchemaField(
description="List of inbox objects, each containing inbox_id, email_address, display_name, etc."
)
count: int = SchemaField(
description="Number of inboxes returned in this page"
)
next_page_token: str = SchemaField(
description="Token to pass as page_token to get the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="cfd84a06-2121-4cef-8d14-8badf52d22f0",
description="List all email inboxes in your AgentMail organization with pagination support.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("inboxes", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_inboxes": lambda *a, **kw: type(
"Resp",
(),
{
"inboxes": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_inboxes(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.inboxes.list(**params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_inboxes(credentials, **params)
inboxes = [i.model_dump() for i in response.inboxes]
yield "inboxes", inboxes
yield "count", (c if (c := response.count) is not None else len(inboxes))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailUpdateInboxBlock(Block):
"""
Update the display name of an existing AgentMail inbox.
Changes the friendly name shown in the 'From' field when emails are sent
from this inbox. The email address itself cannot be changed.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to update (e.g. 'support@agentmail.to')"
)
display_name: str = SchemaField(
description="New display name for the inbox (e.g. 'Customer Support Bot')"
)
class Output(BlockSchemaOutput):
inbox_id: str = SchemaField(description="The updated inbox ID")
result: dict = SchemaField(
description="Complete updated inbox object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="59b49f59-a6d1-4203-94c0-3908adac50b6",
description="Update the display name of an AgentMail inbox. Changes the 'From' name shown when emails are sent.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"display_name": "Updated",
},
test_output=[
("inbox_id", "test-inbox"),
("result", dict),
],
test_mock={
"update_inbox": lambda *a, **kw: type(
"Inbox",
(),
{
"inbox_id": "test-inbox",
"model_dump": lambda self: {"inbox_id": "test-inbox"},
},
)(),
},
)
@staticmethod
async def update_inbox(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.update(inbox_id=inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
inbox = await self.update_inbox(
credentials,
input_data.inbox_id,
display_name=input_data.display_name,
)
result = inbox.model_dump()
yield "inbox_id", inbox.inbox_id
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailDeleteInboxBlock(Block):
"""
Permanently delete an AgentMail inbox and all its data.
This removes the inbox, all its messages, threads, and drafts.
This action cannot be undone. The email address will no longer
receive or send emails.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to permanently delete"
)
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="True if the inbox was successfully deleted"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="ade970ae-8428-4a7b-9278-b52054dbf535",
description="Permanently delete an AgentMail inbox and all its messages, threads, and drafts. This action cannot be undone.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
},
test_output=[("success", True)],
test_mock={
"delete_inbox": lambda *a, **kw: None,
},
)
@staticmethod
async def delete_inbox(credentials: APIKeyCredentials, inbox_id: str):
client = _client(credentials)
await client.inboxes.delete(inbox_id=inbox_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
await self.delete_inbox(credentials, input_data.inbox_id)
yield "success", True
except Exception as e:
yield "error", str(e)
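The `test_mock` entries in these blocks replace each SDK call with a lambda that returns `type(name, bases, namespace)()`, that is, an instance of a class created on the spot. The sketch below shows what that expression builds and compares it with an ordinary class definition. The attribute values are taken from the mocks above; the explicit `Inbox` class is added here only for comparison.

```python
# Build a one-off class named "Inbox" with no base classes, then instantiate it.
mock_inbox = type(
    "Inbox",
    (),
    {
        "inbox_id": "test-inbox",
        "display_name": "",
        # Functions in the namespace become methods, so `self` is bound on call.
        "model_dump": lambda self: {"inbox_id": "test-inbox"},
    },
)()


# The same stand-in expressed as an ordinary class, for comparison.
class Inbox:
    inbox_id = "test-inbox"
    display_name = ""

    def model_dump(self) -> dict:
        return {"inbox_id": self.inbox_id}


print(type(mock_inbox).__name__)
print(mock_inbox.model_dump())
```

The three-argument `type` call keeps each mock inline in the `test_mock` dictionary, at the cost of being harder to read than a named class or `types.SimpleNamespace`.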


@@ -1,384 +0,0 @@
"""
AgentMail List blocks — manage allow/block lists for email filtering.
Lists let you control which email addresses and domains your agents can
send to or receive from. There are four list types based on two dimensions:
direction (send/receive) and type (allow/block).
- receive + allow: Only accept emails from these addresses/domains
- receive + block: Reject emails from these addresses/domains
- send + allow: Only send emails to these addresses/domains
- send + block: Prevent sending emails to these addresses/domains
"""
from enum import Enum
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class ListDirection(str, Enum):
SEND = "send"
RECEIVE = "receive"
class ListType(str, Enum):
ALLOW = "allow"
BLOCK = "block"
class AgentMailListEntriesBlock(Block):
"""
List all entries in an AgentMail allow/block list.
Retrieves email addresses and domains that are currently allowed
or blocked for sending or receiving. Use direction and list_type
to select which of the four lists to query.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
direction: ListDirection = SchemaField(
description="'send' to filter outgoing emails, 'receive' to filter incoming emails"
)
list_type: ListType = SchemaField(
description="'allow' for whitelist (only permit these), 'block' for blacklist (reject these)"
)
limit: int = SchemaField(
description="Maximum number of entries to return per page",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
entries: list[dict] = SchemaField(
description="List of entries, each with an email address or domain"
)
count: int = SchemaField(description="Number of entries returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="01489100-35da-45aa-8a01-9540ba0e9a21",
description="List all entries in an AgentMail allow/block list. Choose send/receive direction and allow/block type.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"direction": "receive",
"list_type": "block",
},
test_output=[
("entries", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_entries": lambda *a, **kw: type(
"Resp",
(),
{
"entries": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_entries(
credentials: APIKeyCredentials, direction: str, list_type: str, **params
):
client = _client(credentials)
return await client.lists.list(direction, list_type, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_entries(
credentials,
input_data.direction.value,
input_data.list_type.value,
**params,
)
entries = [e.model_dump() for e in response.entries]
yield "entries", entries
yield "count", (c if (c := response.count) is not None else len(entries))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailCreateListEntryBlock(Block):
"""
Add an email address or domain to an AgentMail allow/block list.
Entries can be full email addresses (e.g. 'partner@example.com') or
entire domains (e.g. 'example.com'). For block lists, you can optionally
provide a reason (e.g. 'spam', 'competitor').
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
direction: ListDirection = SchemaField(
description="'send' for outgoing email rules, 'receive' for incoming email rules"
)
list_type: ListType = SchemaField(
description="'allow' to whitelist, 'block' to blacklist"
)
entry: str = SchemaField(
description="Email address (user@example.com) or domain (example.com) to add"
)
reason: str = SchemaField(
description="Reason for blocking (only used with block lists, e.g. 'spam', 'competitor')",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
entry: str = SchemaField(
description="The email address or domain that was added"
)
result: dict = SchemaField(description="Complete entry object")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b6650a0a-b113-40cf-8243-ff20f684f9b8",
description="Add an email address or domain to an allow/block list. Block spam senders or whitelist trusted domains.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"direction": "receive",
"list_type": "block",
"entry": "spam@example.com",
},
test_output=[
("entry", "spam@example.com"),
("result", dict),
],
test_mock={
"create_entry": lambda *a, **kw: type(
"Entry",
(),
{
"model_dump": lambda self: {"entry": "spam@example.com"},
},
)(),
},
)
@staticmethod
async def create_entry(
credentials: APIKeyCredentials, direction: str, list_type: str, **params
):
client = _client(credentials)
return await client.lists.create(direction, list_type, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"entry": input_data.entry}
if input_data.reason and input_data.list_type == ListType.BLOCK:
params["reason"] = input_data.reason
result = await self.create_entry(
credentials,
input_data.direction.value,
input_data.list_type.value,
**params,
)
result_dict = result.model_dump()
yield "entry", input_data.entry
yield "result", result_dict
except Exception as e:
yield "error", str(e)
class AgentMailGetListEntryBlock(Block):
"""
Check if an email address or domain exists in an AgentMail allow/block list.
Returns the entry details if found. Use this to verify whether a specific
address or domain is currently allowed or blocked.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
direction: ListDirection = SchemaField(
description="'send' for outgoing rules, 'receive' for incoming rules"
)
list_type: ListType = SchemaField(
description="'allow' for whitelist, 'block' for blacklist"
)
entry: str = SchemaField(description="Email address or domain to look up")
class Output(BlockSchemaOutput):
entry: str = SchemaField(
description="The email address or domain that was found"
)
result: dict = SchemaField(description="Complete entry object with metadata")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="fb117058-ab27-40d1-9231-eb1dd526fc7a",
description="Check if an email address or domain is in an allow/block list. Verify filtering rules.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"direction": "receive",
"list_type": "block",
"entry": "spam@example.com",
},
test_output=[
("entry", "spam@example.com"),
("result", dict),
],
test_mock={
"get_entry": lambda *a, **kw: type(
"Entry",
(),
{
"model_dump": lambda self: {"entry": "spam@example.com"},
},
)(),
},
)
@staticmethod
async def get_entry(
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
):
client = _client(credentials)
return await client.lists.get(direction, list_type, entry=entry)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
result = await self.get_entry(
credentials,
input_data.direction.value,
input_data.list_type.value,
input_data.entry,
)
result_dict = result.model_dump()
yield "entry", input_data.entry
yield "result", result_dict
except Exception as e:
yield "error", str(e)
class AgentMailDeleteListEntryBlock(Block):
"""
Remove an email address or domain from an AgentMail allow/block list.
After removal, the address/domain will no longer be filtered by this list.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
direction: ListDirection = SchemaField(
description="'send' for outgoing rules, 'receive' for incoming rules"
)
list_type: ListType = SchemaField(
description="'allow' for whitelist, 'block' for blacklist"
)
entry: str = SchemaField(
description="Email address or domain to remove from the list"
)
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="True if the entry was successfully removed"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="2b8d57f1-1c9e-470f-a70b-5991c80fad5f",
description="Remove an email address or domain from an allow/block list to stop filtering it.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"direction": "receive",
"list_type": "block",
"entry": "spam@example.com",
},
test_output=[("success", True)],
test_mock={
"delete_entry": lambda *a, **kw: None,
},
)
@staticmethod
async def delete_entry(
credentials: APIKeyCredentials, direction: str, list_type: str, entry: str
):
client = _client(credentials)
await client.lists.delete(direction, list_type, entry=entry)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
await self.delete_entry(
credentials,
input_data.direction.value,
input_data.list_type.value,
input_data.entry,
)
yield "success", True
except Exception as e:
yield "error", str(e)
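The module docstring describes four lists formed from two dimensions, direction and type. As a small illustration, the sketch below enumerates those combinations from the two enums defined in this file. The enum definitions are copied from the module; `all_lists` is a name introduced here for illustration only.

```python
from enum import Enum
from itertools import product


class ListDirection(str, Enum):
    SEND = "send"
    RECEIVE = "receive"


class ListType(str, Enum):
    ALLOW = "allow"
    BLOCK = "block"


# Every (direction, type) pair identifies one of the four lists.
all_lists = [(d.value, t.value) for d, t in product(ListDirection, ListType)]
print(all_lists)

# Because the enums subclass str, members compare equal to their raw values,
# and a plain string such as "receive" maps back onto the matching member.
assert ListDirection("receive") is ListDirection.RECEIVE
assert ListType.BLOCK == "block"
```

The blocks pass `direction.value` and `list_type.value` to the client, so the strings printed here are the ones that reach the API call.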


@@ -1,695 +0,0 @@
"""
AgentMail Message blocks — send, list, get, reply, forward, and update messages.
A Message is an individual email within a Thread. Agents can send new messages
(which create threads), reply to existing messages, forward them, and manage
labels for state tracking (e.g. read/unread, campaign tags).
"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailSendMessageBlock(Block):
"""
Send a new email from an AgentMail inbox, automatically creating a new thread.
Supports plain text and HTML bodies, CC/BCC recipients, and labels for
organizing messages (e.g. campaign tracking, state management).
Max 50 combined recipients across to, cc, and bcc.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to send from (e.g. 'agent@agentmail.to')"
)
to: list[str] = SchemaField(
description="Recipient email addresses (e.g. ['user@example.com'])"
)
subject: str = SchemaField(description="Email subject line")
text: str = SchemaField(
description="Plain text body of the email. Always provide this as a fallback for email clients that don't render HTML."
)
html: str = SchemaField(
description="Rich HTML body of the email. Embed CSS in a <style> tag for best compatibility across email clients.",
default="",
advanced=True,
)
cc: list[str] = SchemaField(
description="CC recipient email addresses for human-in-the-loop oversight",
default_factory=list,
advanced=True,
)
bcc: list[str] = SchemaField(
description="BCC recipient email addresses (hidden from other recipients)",
default_factory=list,
advanced=True,
)
labels: list[str] = SchemaField(
description="Labels to tag the message for filtering and state management (e.g. ['outreach', 'q4-campaign'])",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
message_id: str = SchemaField(
description="Unique identifier of the sent message"
)
thread_id: str = SchemaField(
description="Thread ID grouping this message and any future replies"
)
result: dict = SchemaField(
description="Complete sent message object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b67469b2-7748-4d81-a223-4ebd332cca89",
description="Send a new email from an AgentMail inbox. Creates a new conversation thread. Supports HTML, CC/BCC, and labels.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"to": ["user@example.com"],
"subject": "Test",
"text": "Hello",
},
test_output=[
("message_id", "mock-msg-id"),
("thread_id", "mock-thread-id"),
("result", dict),
],
test_mock={
"send_message": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "mock-msg-id",
"thread_id": "mock-thread-id",
"model_dump": lambda self: {
"message_id": "mock-msg-id",
"thread_id": "mock-thread-id",
},
},
)(),
},
)
@staticmethod
async def send_message(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.messages.send(inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
if total > 50:
raise ValueError(
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
)
params: dict = {
"to": input_data.to,
"subject": input_data.subject,
"text": input_data.text,
}
if input_data.html:
params["html"] = input_data.html
if input_data.cc:
params["cc"] = input_data.cc
if input_data.bcc:
params["bcc"] = input_data.bcc
if input_data.labels:
params["labels"] = input_data.labels
msg = await self.send_message(credentials, input_data.inbox_id, **params)
result = msg.model_dump()
yield "message_id", msg.message_id
yield "thread_id", msg.thread_id or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailListMessagesBlock(Block):
"""
List all messages in an AgentMail inbox with optional label filtering.
Returns a paginated list of messages. Use labels to filter (e.g.
labels=['unread'] to only get unprocessed messages). Useful for
polling workflows or building inbox views.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to list messages from"
)
limit: int = SchemaField(
description="Maximum number of messages to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
labels: list[str] = SchemaField(
description="Only return messages with ALL of these labels (e.g. ['unread'] or ['q4-campaign', 'follow-up'])",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
messages: list[dict] = SchemaField(
description="List of message objects with subject, sender, text, html, labels, etc."
)
count: int = SchemaField(description="Number of messages returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="721234df-c7a2-4927-b205-744badbd5844",
description="List messages in an AgentMail inbox. Filter by labels to find unread, campaign-tagged, or categorized messages.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
},
test_output=[
("messages", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_messages": lambda *a, **kw: type(
"Resp",
(),
{
"messages": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_messages(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.messages.list(inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
if input_data.labels:
params["labels"] = input_data.labels
response = await self.list_messages(
credentials, input_data.inbox_id, **params
)
messages = [m.model_dump() for m in response.messages]
yield "messages", messages
yield "count", (c if (c := response.count) is not None else len(messages))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailGetMessageBlock(Block):
"""
Retrieve a specific email message by ID from an AgentMail inbox.
Returns the full message including subject, body (text and HTML),
sender, recipients, and attachments. Use extracted_text to get
only the new reply content without quoted history.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the message belongs to"
)
message_id: str = SchemaField(
description="Message ID to retrieve (e.g. '<abc123@agentmail.to>')"
)
class Output(BlockSchemaOutput):
message_id: str = SchemaField(description="Unique identifier of the message")
thread_id: str = SchemaField(description="Thread this message belongs to")
subject: str = SchemaField(description="Email subject line")
text: str = SchemaField(
description="Full plain text body (may include quoted reply history)"
)
extracted_text: str = SchemaField(
description="Just the new reply content with quoted history stripped. Best for AI processing.",
default="",
)
html: str = SchemaField(description="HTML body of the email", default="")
result: dict = SchemaField(
description="Complete message object with all fields including sender, recipients, attachments, labels"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="2788bdfa-1527-4603-a5e4-a455c05c032f",
description="Retrieve a specific email message by ID. Includes extracted_text for clean reply content without quoted history.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"message_id": "test-msg",
},
test_output=[
("message_id", "test-msg"),
("thread_id", "t1"),
("subject", "Hi"),
("text", "Hello"),
("extracted_text", "Hello"),
("html", ""),
("result", dict),
],
test_mock={
"get_message": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "test-msg",
"thread_id": "t1",
"subject": "Hi",
"text": "Hello",
"extracted_text": "Hello",
"html": "",
"model_dump": lambda self: {"message_id": "test-msg"},
},
)(),
},
)
@staticmethod
async def get_message(
credentials: APIKeyCredentials,
inbox_id: str,
message_id: str,
):
client = _client(credentials)
return await client.inboxes.messages.get(
inbox_id=inbox_id, message_id=message_id
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
msg = await self.get_message(
credentials, input_data.inbox_id, input_data.message_id
)
result = msg.model_dump()
yield "message_id", msg.message_id
yield "thread_id", msg.thread_id or ""
yield "subject", msg.subject or ""
yield "text", msg.text or ""
yield "extracted_text", msg.extracted_text or ""
yield "html", msg.html or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailReplyToMessageBlock(Block):
"""
Reply to an existing email message, keeping the reply in the same thread.
The reply is automatically added to the same conversation thread as the
original message. Use this for multi-turn agent conversations.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to send the reply from"
)
message_id: str = SchemaField(
description="Message ID to reply to (e.g. '<abc123@agentmail.to>')"
)
text: str = SchemaField(description="Plain text body of the reply")
html: str = SchemaField(
description="Rich HTML body of the reply",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
message_id: str = SchemaField(
description="Unique identifier of the reply message"
)
thread_id: str = SchemaField(description="Thread ID the reply was added to")
result: dict = SchemaField(
description="Complete reply message object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b9fe53fa-5026-4547-9570-b54ccb487229",
description="Reply to an existing email in the same conversation thread. Use for multi-turn agent conversations.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"message_id": "test-msg",
"text": "Reply",
},
test_output=[
("message_id", "mock-reply-id"),
("thread_id", "mock-thread-id"),
("result", dict),
],
test_mock={
"reply_to_message": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "mock-reply-id",
"thread_id": "mock-thread-id",
"model_dump": lambda self: {"message_id": "mock-reply-id"},
},
)(),
},
)
@staticmethod
async def reply_to_message(
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
):
client = _client(credentials)
return await client.inboxes.messages.reply(
inbox_id=inbox_id, message_id=message_id, **params
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"text": input_data.text}
if input_data.html:
params["html"] = input_data.html
reply = await self.reply_to_message(
credentials,
input_data.inbox_id,
input_data.message_id,
**params,
)
result = reply.model_dump()
yield "message_id", reply.message_id
yield "thread_id", reply.thread_id or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailForwardMessageBlock(Block):
"""
Forward an existing email message to one or more recipients.
Sends the original message content to different email addresses.
Optionally prepend additional text or override the subject line.
Max 50 combined recipients across to, cc, and bcc.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to forward from"
)
message_id: str = SchemaField(description="Message ID to forward")
to: list[str] = SchemaField(
description="Recipient email addresses to forward the message to (e.g. ['user@example.com'])"
)
cc: list[str] = SchemaField(
description="CC recipient email addresses",
default_factory=list,
advanced=True,
)
bcc: list[str] = SchemaField(
description="BCC recipient email addresses (hidden from other recipients)",
default_factory=list,
advanced=True,
)
subject: str = SchemaField(
description="Override the subject line (defaults to 'Fwd: <original subject>')",
default="",
advanced=True,
)
text: str = SchemaField(
description="Additional plain text to prepend before the forwarded content",
default="",
advanced=True,
)
html: str = SchemaField(
description="Additional HTML to prepend before the forwarded content",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
message_id: str = SchemaField(
description="Unique identifier of the forwarded message"
)
thread_id: str = SchemaField(description="Thread ID of the forward")
result: dict = SchemaField(
description="Complete forwarded message object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b70c7e33-5d66-4f8e-897f-ac73a7bfce82",
description="Forward an email message to one or more recipients. Supports CC/BCC and optional extra text or subject override.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"message_id": "test-msg",
"to": ["user@example.com"],
},
test_output=[
("message_id", "mock-fwd-id"),
("thread_id", "mock-thread-id"),
("result", dict),
],
test_mock={
"forward_message": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "mock-fwd-id",
"thread_id": "mock-thread-id",
"model_dump": lambda self: {"message_id": "mock-fwd-id"},
},
)(),
},
)
@staticmethod
async def forward_message(
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
):
client = _client(credentials)
return await client.inboxes.messages.forward(
inbox_id=inbox_id, message_id=message_id, **params
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
total = len(input_data.to) + len(input_data.cc) + len(input_data.bcc)
if total > 50:
raise ValueError(
f"Max 50 combined recipients across to, cc, and bcc (got {total})"
)
params: dict = {"to": input_data.to}
if input_data.cc:
params["cc"] = input_data.cc
if input_data.bcc:
params["bcc"] = input_data.bcc
if input_data.subject:
params["subject"] = input_data.subject
if input_data.text:
params["text"] = input_data.text
if input_data.html:
params["html"] = input_data.html
fwd = await self.forward_message(
credentials,
input_data.inbox_id,
input_data.message_id,
**params,
)
result = fwd.model_dump()
yield "message_id", fwd.message_id
yield "thread_id", fwd.thread_id or ""
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailUpdateMessageBlock(Block):
"""
Add or remove labels on an email message for state management.
Labels are string tags used to track message state (read/unread),
categorize messages (billing, support), or tag campaigns (q4-outreach).
Common pattern: add 'read' and remove 'unread' after processing a message.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the message belongs to"
)
message_id: str = SchemaField(description="Message ID to update labels on")
add_labels: list[str] = SchemaField(
description="Labels to add (e.g. ['read', 'processed', 'high-priority'])",
default_factory=list,
)
remove_labels: list[str] = SchemaField(
description="Labels to remove (e.g. ['unread', 'pending'])",
default_factory=list,
)
class Output(BlockSchemaOutput):
message_id: str = SchemaField(description="The updated message ID")
result: dict = SchemaField(
description="Complete updated message object with current labels"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="694ff816-4c89-4a5e-a552-8c31be187735",
description="Add or remove labels on an email message. Use for read/unread tracking, campaign tagging, or state management.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"message_id": "test-msg",
"add_labels": ["read"],
},
test_output=[
("message_id", "test-msg"),
("result", dict),
],
test_mock={
"update_message": lambda *a, **kw: type(
"Msg",
(),
{
"message_id": "test-msg",
"model_dump": lambda self: {"message_id": "test-msg"},
},
)(),
},
)
@staticmethod
async def update_message(
credentials: APIKeyCredentials, inbox_id: str, message_id: str, **params
):
client = _client(credentials)
return await client.inboxes.messages.update(
inbox_id=inbox_id, message_id=message_id, **params
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
if not input_data.add_labels and not input_data.remove_labels:
raise ValueError(
"Must specify at least one label operation: add_labels or remove_labels"
)
params: dict = {}
if input_data.add_labels:
params["add_labels"] = input_data.add_labels
if input_data.remove_labels:
params["remove_labels"] = input_data.remove_labels
msg = await self.update_message(
credentials,
input_data.inbox_id,
input_data.message_id,
**params,
)
result = msg.model_dump()
yield "message_id", msg.message_id
yield "result", result
except Exception as e:
yield "error", str(e)

@@ -1,651 +0,0 @@
"""
AgentMail Pod blocks — create, get, list, delete pods and list pod-scoped resources.
Pods provide multi-tenant isolation between your customers. Each pod acts as
an isolated workspace containing its own inboxes, domains, threads, and drafts.
Use pods when building SaaS platforms, agency tools, or AI agent fleets that
serve multiple customers.
"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailCreatePodBlock(Block):
"""
Create a new pod for multi-tenant customer isolation.
Each pod acts as an isolated workspace for one customer or tenant.
Use client_id to map pods to your internal tenant IDs for idempotent
creation (safe to retry without creating duplicates).
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
client_id: str = SchemaField(
description="Your internal tenant/customer ID for idempotent mapping. Lets you access the pod by your own ID instead of AgentMail's pod_id.",
default="",
)
class Output(BlockSchemaOutput):
pod_id: str = SchemaField(description="Unique identifier of the created pod")
result: dict = SchemaField(description="Complete pod object with all metadata")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="a2db9784-2d17-4f8f-9d6b-0214e6f22101",
description="Create a new pod for multi-tenant customer isolation. Use client_id to map to your internal tenant IDs.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("pod_id", "mock-pod-id"),
("result", dict),
],
test_mock={
"create_pod": lambda *a, **kw: type(
"Pod",
(),
{
"pod_id": "mock-pod-id",
"model_dump": lambda self: {"pod_id": "mock-pod-id"},
},
)(),
},
)
@staticmethod
async def create_pod(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.pods.create(**params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {}
if input_data.client_id:
params["client_id"] = input_data.client_id
pod = await self.create_pod(credentials, **params)
result = pod.model_dump()
yield "pod_id", pod.pod_id
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailGetPodBlock(Block):
"""
Retrieve details of an existing pod by its ID.
Returns the pod metadata including its client_id mapping and
creation timestamp.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(description="Pod ID to retrieve")
class Output(BlockSchemaOutput):
pod_id: str = SchemaField(description="Unique identifier of the pod")
result: dict = SchemaField(description="Complete pod object with all metadata")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="553361bc-bb1b-4322-9ad4-0c226200217e",
description="Retrieve details of an existing pod including its client_id mapping and metadata.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[
("pod_id", "test-pod"),
("result", dict),
],
test_mock={
"get_pod": lambda *a, **kw: type(
"Pod",
(),
{
"pod_id": "test-pod",
"model_dump": lambda self: {"pod_id": "test-pod"},
},
)(),
},
)
@staticmethod
async def get_pod(credentials: APIKeyCredentials, pod_id: str):
client = _client(credentials)
return await client.pods.get(pod_id=pod_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
pod = await self.get_pod(credentials, pod_id=input_data.pod_id)
result = pod.model_dump()
yield "pod_id", pod.pod_id
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailListPodsBlock(Block):
"""
List all pods in your AgentMail organization.
Returns a paginated list of all tenant pods with their metadata.
Use this to see all customer workspaces at a glance.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
limit: int = SchemaField(
description="Maximum number of pods to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
pods: list[dict] = SchemaField(
description="List of pod objects with pod_id, client_id, creation time, etc."
)
count: int = SchemaField(description="Number of pods returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="9d3725ee-2968-431a-a816-857ab41e1420",
description="List all tenant pods in your organization. See all customer workspaces at a glance.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("pods", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_pods": lambda *a, **kw: type(
"Resp",
(),
{
"pods": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_pods(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.pods.list(**params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_pods(credentials, **params)
pods = [p.model_dump() for p in response.pods]
yield "pods", pods
yield "count", (c if (c := response.count) is not None else len(pods))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailDeletePodBlock(Block):
"""
Permanently delete a pod. All inboxes and domains must be removed first.
You cannot delete a pod that still contains inboxes or domains;
delete those child resources first, then delete the pod.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(
description="Pod ID to permanently delete (must have no inboxes or domains)"
)
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="True if the pod was successfully deleted"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="f371f8cd-682d-4f5f-905c-529c74a8fb35",
description="Permanently delete a pod. All inboxes and domains must be removed first.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[("success", True)],
test_mock={
"delete_pod": lambda *a, **kw: None,
},
)
@staticmethod
async def delete_pod(credentials: APIKeyCredentials, pod_id: str):
client = _client(credentials)
await client.pods.delete(pod_id=pod_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
await self.delete_pod(credentials, pod_id=input_data.pod_id)
yield "success", True
except Exception as e:
yield "error", str(e)
class AgentMailListPodInboxesBlock(Block):
"""
List all inboxes within a specific pod (customer workspace).
Returns only the inboxes belonging to this pod, providing
tenant-scoped visibility.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(description="Pod ID to list inboxes from")
limit: int = SchemaField(
description="Maximum number of inboxes to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
inboxes: list[dict] = SchemaField(
description="List of inbox objects within this pod"
)
count: int = SchemaField(description="Number of inboxes returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="a8c17ce0-b7c1-4bc3-ae39-680e1952e5d0",
description="List all inboxes within a pod. View email accounts scoped to a specific customer.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[
("inboxes", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_pod_inboxes": lambda *a, **kw: type(
"Resp",
(),
{
"inboxes": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_pod_inboxes(credentials: APIKeyCredentials, pod_id: str, **params):
client = _client(credentials)
return await client.pods.inboxes.list(pod_id=pod_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_pod_inboxes(
credentials, pod_id=input_data.pod_id, **params
)
inboxes = [i.model_dump() for i in response.inboxes]
yield "inboxes", inboxes
yield "count", (c if (c := response.count) is not None else len(inboxes))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailListPodThreadsBlock(Block):
"""
List all conversation threads across all inboxes within a pod.
Returns threads from every inbox in the pod. Use for building
per-customer dashboards showing all email activity, or for
supervisor agents monitoring a customer's conversations.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(description="Pod ID to list threads from")
limit: int = SchemaField(
description="Maximum number of threads to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
labels: list[str] = SchemaField(
description="Only return threads matching ALL of these labels",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
threads: list[dict] = SchemaField(
description="List of thread objects from all inboxes in this pod"
)
count: int = SchemaField(description="Number of threads returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="80214f08-8b85-4533-a6b8-f8123bfcb410",
description="List all conversation threads across all inboxes within a pod. View all email activity for a customer.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[
("threads", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_pod_threads": lambda *a, **kw: type(
"Resp",
(),
{
"threads": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_pod_threads(credentials: APIKeyCredentials, pod_id: str, **params):
client = _client(credentials)
return await client.pods.threads.list(pod_id=pod_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
if input_data.labels:
params["labels"] = input_data.labels
response = await self.list_pod_threads(
credentials, pod_id=input_data.pod_id, **params
)
threads = [t.model_dump() for t in response.threads]
yield "threads", threads
yield "count", (c if (c := response.count) is not None else len(threads))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailListPodDraftsBlock(Block):
"""
List all drafts across all inboxes within a pod.
Returns pending drafts from every inbox in the pod. Use for
per-customer approval dashboards or monitoring scheduled sends.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(description="Pod ID to list drafts from")
limit: int = SchemaField(
description="Maximum number of drafts to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
class Output(BlockSchemaOutput):
drafts: list[dict] = SchemaField(
description="List of draft objects from all inboxes in this pod"
)
count: int = SchemaField(description="Number of drafts returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="12fd7a3e-51ad-4b20-97c1-0391f207f517",
description="List all drafts across all inboxes within a pod. View pending emails for a customer.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[
("drafts", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_pod_drafts": lambda *a, **kw: type(
"Resp",
(),
{
"drafts": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_pod_drafts(credentials: APIKeyCredentials, pod_id: str, **params):
client = _client(credentials)
return await client.pods.drafts.list(pod_id=pod_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
response = await self.list_pod_drafts(
credentials, pod_id=input_data.pod_id, **params
)
drafts = [d.model_dump() for d in response.drafts]
yield "drafts", drafts
yield "count", response.count
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailCreatePodInboxBlock(Block):
"""
Create a new email inbox within a specific pod (customer workspace).
The inbox is automatically scoped to the pod and inherits its
isolation guarantees. If username/domain are not provided,
AgentMail auto-generates a unique address.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
pod_id: str = SchemaField(description="Pod ID to create the inbox in")
username: str = SchemaField(
description="Local part of the email address (e.g. 'support'). Leave empty to auto-generate.",
default="",
)
domain: str = SchemaField(
description="Email domain (e.g. 'mydomain.com'). Defaults to agentmail.to if empty.",
default="",
)
display_name: str = SchemaField(
description="Friendly name shown in the 'From' field (e.g. 'Customer Support')",
default="",
)
class Output(BlockSchemaOutput):
inbox_id: str = SchemaField(
description="Unique identifier of the created inbox"
)
email_address: str = SchemaField(description="Full email address of the inbox")
result: dict = SchemaField(
description="Complete inbox object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="c6862373-1ac6-402e-89e6-7db1fea882af",
description="Create a new email inbox within a pod. The inbox is scoped to the customer workspace.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT, "pod_id": "test-pod"},
test_output=[
("inbox_id", "mock-inbox-id"),
("email_address", "mock-inbox-id"),
("result", dict),
],
test_mock={
"create_pod_inbox": lambda *a, **kw: type(
"Inbox",
(),
{
"inbox_id": "mock-inbox-id",
"model_dump": lambda self: {"inbox_id": "mock-inbox-id"},
},
)(),
},
)
@staticmethod
async def create_pod_inbox(credentials: APIKeyCredentials, pod_id: str, **params):
client = _client(credentials)
return await client.pods.inboxes.create(pod_id=pod_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {}
if input_data.username:
params["username"] = input_data.username
if input_data.domain:
params["domain"] = input_data.domain
if input_data.display_name:
params["display_name"] = input_data.display_name
inbox = await self.create_pod_inbox(
credentials, pod_id=input_data.pod_id, **params
)
result = inbox.model_dump()
yield "inbox_id", inbox.inbox_id
yield "email_address", inbox.inbox_id
yield "result", result
except Exception as e:
yield "error", str(e)
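Reviewer note: the list blocks in this file return a `next_page_token` that a caller passes back as `page_token` on the next request. Below is a minimal sketch of draining every page. It assumes a hypothetical `fetch_page` coroutine and dict-shaped pages, neither of which is part of the AgentMail SDK. `fake_fetch` exists only so the example runs offline.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def collect_all_pages(
    fetch_page: Callable[..., Awaitable[dict[str, Any]]],
    limit: int = 20,
    max_pages: int = 50,
) -> list[Any]:
    """Follow next_page_token until it comes back empty (or max_pages is hit)."""
    items: list[Any] = []
    token = ""
    for _ in range(max_pages):
        # Mirror the blocks: only send page_token when one is set.
        params: dict[str, Any] = {"limit": limit}
        if token:
            params["page_token"] = token
        page = await fetch_page(**params)
        items.extend(page["items"])
        token = page.get("next_page_token") or ""
        if not token:
            break
    return items


async def fake_fetch(limit: int, page_token: str = "") -> dict[str, Any]:
    """Hypothetical stand-in for an API call: pages through the numbers 0..4."""
    data = list(range(5))
    start = int(page_token) if page_token else 0
    end = start + limit
    return {
        "items": data[start:end],
        "next_page_token": str(end) if end < len(data) else "",
    }


def drain_fake(limit: int = 2) -> list[Any]:
    """Synchronous helper for trying the sketch from a script or REPL."""
    return asyncio.run(collect_all_pages(fake_fetch, limit=limit))
```

The `max_pages` cap is a defensive choice for this sketch, so a server that keeps returning a token cannot loop forever.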

@@ -1,438 +0,0 @@
"""
AgentMail Thread blocks — list, get, and delete conversation threads.
A Thread groups related messages into a single conversation. Threads are
created automatically when a new message is sent and grow as replies are added.
Threads can be queried per-inbox or across the entire organization.
"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
CredentialsMetaInput,
SchemaField,
)
from ._config import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, _client, agent_mail
class AgentMailListInboxThreadsBlock(Block):
"""
List all conversation threads within a specific AgentMail inbox.
Returns a paginated list of threads with optional label filtering.
Use labels to find threads by campaign, status, or custom tags.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address to list threads from"
)
limit: int = SchemaField(
description="Maximum number of threads to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
labels: list[str] = SchemaField(
description="Only return threads matching ALL of these labels (e.g. ['q4-campaign', 'follow-up'])",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
threads: list[dict] = SchemaField(
description="List of thread objects with thread_id, subject, message count, labels, etc."
)
count: int = SchemaField(description="Number of threads returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="63dd9e2d-ef81-405c-b034-c031f0437334",
description="List all conversation threads in an AgentMail inbox. Filter by labels for campaign tracking or status management.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
},
test_output=[
("threads", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_threads": lambda *a, **kw: type(
"Resp",
(),
{
"threads": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_threads(credentials: APIKeyCredentials, inbox_id: str, **params):
client = _client(credentials)
return await client.inboxes.threads.list(inbox_id=inbox_id, **params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
if input_data.labels:
params["labels"] = input_data.labels
response = await self.list_threads(
credentials, input_data.inbox_id, **params
)
threads = [t.model_dump() for t in response.threads]
yield "threads", threads
yield "count", (c if (c := response.count) is not None else len(threads))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailGetInboxThreadBlock(Block):
"""
Retrieve a single conversation thread from an AgentMail inbox.
Returns the thread with all its messages in chronological order.
Use this to get the full conversation history for context when
composing replies.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the thread belongs to"
)
thread_id: str = SchemaField(description="Thread ID to retrieve")
class Output(BlockSchemaOutput):
thread_id: str = SchemaField(description="Unique identifier of the thread")
messages: list[dict] = SchemaField(
description="All messages in the thread, in chronological order"
)
result: dict = SchemaField(
description="Complete thread object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="42866290-1479-4153-83e7-550b703e9da2",
description="Retrieve a conversation thread with all its messages. Use for getting full conversation context before replying.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"thread_id": "test-thread",
},
test_output=[
("thread_id", "test-thread"),
("messages", []),
("result", dict),
],
test_mock={
"get_thread": lambda *a, **kw: type(
"Thread",
(),
{
"thread_id": "test-thread",
"messages": [],
"model_dump": lambda self: {
"thread_id": "test-thread",
"messages": [],
},
},
)(),
},
)
@staticmethod
async def get_thread(credentials: APIKeyCredentials, inbox_id: str, thread_id: str):
client = _client(credentials)
return await client.inboxes.threads.get(inbox_id=inbox_id, thread_id=thread_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
thread = await self.get_thread(
credentials, input_data.inbox_id, input_data.thread_id
)
messages = [m.model_dump() for m in thread.messages]
result = thread.model_dump()
result["messages"] = messages
yield "thread_id", thread.thread_id
yield "messages", messages
yield "result", result
except Exception as e:
yield "error", str(e)
class AgentMailDeleteInboxThreadBlock(Block):
"""
Permanently delete a conversation thread and all its messages from an inbox.
This removes the thread and every message within it. This action
cannot be undone.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
inbox_id: str = SchemaField(
description="Inbox ID or email address the thread belongs to"
)
thread_id: str = SchemaField(description="Thread ID to permanently delete")
class Output(BlockSchemaOutput):
success: bool = SchemaField(
description="True if the thread was successfully deleted"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="18cd5f6f-4ff6-45da-8300-25a50ea7fb75",
description="Permanently delete a conversation thread and all its messages. This action cannot be undone.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
is_sensitive_action=True,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"inbox_id": "test-inbox",
"thread_id": "test-thread",
},
test_output=[("success", True)],
test_mock={
"delete_thread": lambda *a, **kw: None,
},
)
@staticmethod
async def delete_thread(
credentials: APIKeyCredentials, inbox_id: str, thread_id: str
):
client = _client(credentials)
await client.inboxes.threads.delete(inbox_id=inbox_id, thread_id=thread_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
await self.delete_thread(
credentials, input_data.inbox_id, input_data.thread_id
)
yield "success", True
except Exception as e:
yield "error", str(e)
class AgentMailListOrgThreadsBlock(Block):
"""
List conversation threads across ALL inboxes in your organization.
Unlike per-inbox listing, this returns threads from every inbox.
Ideal for building supervisor agents that monitor all conversations,
analytics dashboards, or cross-agent routing workflows.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
limit: int = SchemaField(
description="Maximum number of threads to return per page (1-100)",
default=20,
advanced=True,
)
page_token: str = SchemaField(
description="Token from a previous response to fetch the next page",
default="",
advanced=True,
)
labels: list[str] = SchemaField(
description="Only return threads matching ALL of these labels",
default_factory=list,
advanced=True,
)
class Output(BlockSchemaOutput):
threads: list[dict] = SchemaField(
description="List of thread objects from all inboxes in the organization"
)
count: int = SchemaField(description="Number of threads returned")
next_page_token: str = SchemaField(
description="Token for the next page. Empty if no more results.",
default="",
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="d7a0657b-58ab-48b2-898b-7bd94f44a708",
description="List threads across ALL inboxes in your organization. Use for supervisor agents, dashboards, or cross-agent monitoring.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
("threads", []),
("count", 0),
("next_page_token", ""),
],
test_mock={
"list_org_threads": lambda *a, **kw: type(
"Resp",
(),
{
"threads": [],
"count": 0,
"next_page_token": "",
},
)(),
},
)
@staticmethod
async def list_org_threads(credentials: APIKeyCredentials, **params):
client = _client(credentials)
return await client.threads.list(**params)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
params: dict = {"limit": input_data.limit}
if input_data.page_token:
params["page_token"] = input_data.page_token
if input_data.labels:
params["labels"] = input_data.labels
response = await self.list_org_threads(credentials, **params)
threads = [t.model_dump() for t in response.threads]
yield "threads", threads
yield "count", (c if (c := response.count) is not None else len(threads))
yield "next_page_token", response.next_page_token or ""
except Exception as e:
yield "error", str(e)
class AgentMailGetOrgThreadBlock(Block):
"""
Retrieve a single conversation thread by ID from anywhere in the organization.
Works without needing to know which inbox the thread belongs to.
Returns the thread with all its messages in chronological order.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = agent_mail.credentials_field(
description="AgentMail API key from https://console.agentmail.to"
)
thread_id: str = SchemaField(
description="Thread ID to retrieve (works across all inboxes)"
)
class Output(BlockSchemaOutput):
thread_id: str = SchemaField(description="Unique identifier of the thread")
messages: list[dict] = SchemaField(
description="All messages in the thread, in chronological order"
)
result: dict = SchemaField(
description="Complete thread object with all metadata"
)
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="39aaae31-3eb1-44c6-9e37-5a44a4529649",
description="Retrieve a conversation thread by ID from anywhere in the organization, without needing the inbox ID.",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"thread_id": "test-thread",
},
test_output=[
("thread_id", "test-thread"),
("messages", []),
("result", dict),
],
test_mock={
"get_org_thread": lambda *a, **kw: type(
"Thread",
(),
{
"thread_id": "test-thread",
"messages": [],
"model_dump": lambda self: {
"thread_id": "test-thread",
"messages": [],
},
},
)(),
},
)
@staticmethod
async def get_org_thread(credentials: APIKeyCredentials, thread_id: str):
client = _client(credentials)
return await client.threads.get(thread_id=thread_id)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
thread = await self.get_org_thread(credentials, input_data.thread_id)
messages = [m.model_dump() for m in thread.messages]
result = thread.model_dump()
result["messages"] = messages
yield "thread_id", thread.thread_id
yield "messages", messages
yield "result", result
except Exception as e:
yield "error", str(e)
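Reviewer note: the two thread-listing blocks in this file compute `count` with an explicit `is not None` check instead of `or`. The sketch below shows why the distinction matters. The helper names and `SimpleNamespace` responses are illustrative only and do not appear in the diff.

```python
from types import SimpleNamespace
from typing import Any, Sequence


def effective_count(response: Any, items: Sequence[Any]) -> int:
    """Use the reported count unless it is None, as the thread blocks do."""
    return c if (c := response.count) is not None else len(items)


def or_count(response: Any, items: Sequence[Any]) -> int:
    """Shorter variant that also discards a legitimate count of 0."""
    return response.count or len(items)


missing = SimpleNamespace(count=None)
zero = SimpleNamespace(count=0)
```

With `zero` and a two-item list, `effective_count` keeps the reported `0`, while `or_count` falls through to the list length.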

@@ -27,7 +27,6 @@ from backend.util.file import MediaFileType, store_media_file
 class GeminiImageModel(str, Enum):
     NANO_BANANA = "google/nano-banana"
     NANO_BANANA_PRO = "google/nano-banana-pro"
-    NANO_BANANA_2 = "google/nano-banana-2"
 class AspectRatio(str, Enum):
@@ -78,7 +77,7 @@ class AIImageCustomizerBlock(Block):
         )
         model: GeminiImageModel = SchemaField(
             description="The AI model to use for image generation and editing",
-            default=GeminiImageModel.NANO_BANANA_2,
+            default=GeminiImageModel.NANO_BANANA,
             title="Model",
         )
         images: list[MediaFileType] = SchemaField(
@@ -104,7 +103,7 @@ class AIImageCustomizerBlock(Block):
         super().__init__(
             id="d76bbe4c-930e-4894-8469-b66775511f71",
             description=(
-                "Generate and edit custom images using Google's Nano-Banana models from Gemini. "
+                "Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
                 "Provide a prompt and optional reference images to create or modify images."
             ),
             categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
@@ -112,7 +111,7 @@ class AIImageCustomizerBlock(Block):
             output_schema=AIImageCustomizerBlock.Output,
             test_input={
                 "prompt": "Make the scene more vibrant and colorful",
-                "model": GeminiImageModel.NANO_BANANA_2,
+                "model": GeminiImageModel.NANO_BANANA,
                 "images": [],
                 "aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
                 "output_format": OutputFormat.JPG,

@@ -115,7 +115,6 @@ class ImageGenModel(str, Enum):
     RECRAFT = "Recraft v3"
     SD3_5 = "Stable Diffusion 3.5 Medium"
     NANO_BANANA_PRO = "Nano Banana Pro"
-    NANO_BANANA_2 = "Nano Banana 2"
 class AIImageGeneratorBlock(Block):
@@ -132,7 +131,7 @@ class AIImageGeneratorBlock(Block):
         )
         model: ImageGenModel = SchemaField(
             description="The AI model to use for image generation",
-            default=ImageGenModel.NANO_BANANA_2,
+            default=ImageGenModel.SD3_5,
             title="Model",
         )
         size: ImageSize = SchemaField(
@@ -166,7 +165,7 @@ class AIImageGeneratorBlock(Block):
             test_input={
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "prompt": "An octopus using a laptop in a snowy forest with 'AutoGPT' clearly visible on the screen",
-                "model": ImageGenModel.NANO_BANANA_2,
+                "model": ImageGenModel.RECRAFT,
                 "size": ImageSize.SQUARE,
                 "style": ImageStyle.REALISTIC,
             },
@@ -180,9 +179,7 @@ class AIImageGeneratorBlock(Block):
             ],
             test_mock={
                 # Return a data URI directly so store_media_file doesn't need to download
-                "_run_client": lambda *args, **kwargs: (
-                    "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
-                )
+                "_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
             },
         )
@@ -283,24 +280,17 @@ class AIImageGeneratorBlock(Block):
             )
             return output
-        elif input_data.model in (
-            ImageGenModel.NANO_BANANA_PRO,
-            ImageGenModel.NANO_BANANA_2,
-        ):
-            # Use Nano Banana models (Google Gemini image variants)
-            model_map = {
-                ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
-                ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
-            }
+        elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
+            # Use Nano Banana Pro (Google Gemini 3 Pro Image)
             input_params = {
                 "prompt": modified_prompt,
                 "aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
-                "resolution": "2K",
+                "resolution": "2K",  # Default to 2K for good quality/cost balance
                 "output_format": "jpg",
-                "safety_filter_level": "block_only_high",
+                "safety_filter_level": "block_only_high",  # Most permissive
             }
             output = await self._run_client(
-                credentials, model_map[input_data.model], input_params
+                credentials, "google/nano-banana-pro", input_params
             )
             return output
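Reviewer note: one side of the last hunk dispatches through a per-model lookup table and the other hard-codes a single slug. Below is a minimal sketch of the table-based variant, using only the enum members and slugs that appear in this diff. `MODEL_SLUGS` and `resolve_slug` are hypothetical names, not code from the repository.

```python
from enum import Enum
from typing import Optional


class ImageGenModel(str, Enum):
    SD3_5 = "Stable Diffusion 3.5 Medium"
    NANO_BANANA_PRO = "Nano Banana Pro"
    NANO_BANANA_2 = "Nano Banana 2"


# Illustrative table: display-name enum member -> provider model slug.
MODEL_SLUGS = {
    ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
    ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
}


def resolve_slug(model: ImageGenModel) -> Optional[str]:
    """Return the slug for a Nano Banana model, or None for any other model."""
    return MODEL_SLUGS.get(model)
```

A table keeps the branch body identical for every variant, so adding a model becomes a one-line change. The equality check is simpler when only one variant exists.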

@@ -1,376 +0,0 @@
from __future__ import annotations
import asyncio
import contextvars
import json
import logging
from typing import TYPE_CHECKING, Any
from typing_extensions import TypedDict # Needed for Python <3.12 compatibility
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
logger = logging.getLogger(__name__)
# Block ID shared between autopilot.py and copilot prompting.py.
AUTOPILOT_BLOCK_ID = "c069dc6b-c3ed-4c12-b6e5-d47361e64ce6"
class ToolCallEntry(TypedDict):
"""A single tool invocation record from an autopilot execution."""
tool_call_id: str
tool_name: str
input: Any
output: Any | None
success: bool | None
class TokenUsage(TypedDict):
"""Aggregated token counts from the autopilot stream."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class AutoPilotBlock(Block):
"""Execute tasks using AutoGPT AutoPilot with full access to platform tools.
The autopilot can manage agents, access workspace files, fetch web content,
run blocks, and more. This block enables sub-agent patterns (autopilot calling
autopilot) and scheduled autopilot execution via the agent executor.
"""
class Input(BlockSchemaInput):
"""Input schema for the AutoPilot block."""
prompt: str = SchemaField(
description=(
"The task or instruction for the autopilot to execute. "
"The autopilot has access to platform tools like agent management, "
"workspace files, web fetch, block execution, and more."
),
placeholder="Find my agents and list them",
advanced=False,
)
system_context: str = SchemaField(
description=(
"Optional additional context prepended to the prompt. "
"Use this to constrain autopilot behavior, provide domain "
"context, or set output format requirements."
),
default="",
advanced=True,
)
session_id: str = SchemaField(
description=(
"Session ID to continue an existing autopilot conversation. "
"Leave empty to start a new session. "
"Use the session_id output from a previous run to continue."
),
default="",
advanced=True,
)
max_recursion_depth: int = SchemaField(
description=(
"Maximum nesting depth when the autopilot calls this block "
"recursively (sub-agent pattern). Prevents infinite loops."
),
default=3,
ge=1,
le=10,
advanced=True,
)
# timeout_seconds removed: the SDK manages its own heartbeat-based
# timeouts internally; wrapping with asyncio.timeout corrupts the
# SDK's internal stream (see service.py CRITICAL comment).
class Output(BlockSchemaOutput):
"""Output schema for the AutoPilot block."""
response: str = SchemaField(
description="The final text response from the autopilot."
)
tool_calls: list[ToolCallEntry] = SchemaField(
description=(
"List of tools called during execution. Each entry has "
"tool_call_id, tool_name, input, output, and success fields."
),
)
conversation_history: str = SchemaField(
description=(
"Current turn messages (user prompt + assistant reply) as JSON. "
"It can be used for logging or analysis."
),
)
session_id: str = SchemaField(
description=(
"Session ID for this conversation. "
"Pass this back to continue the conversation in a future run."
),
)
token_usage: TokenUsage = SchemaField(
description=(
"Token usage statistics: prompt_tokens, "
"completion_tokens, total_tokens."
),
)
def __init__(self):
super().__init__(
id=AUTOPILOT_BLOCK_ID,
description=(
"Execute tasks using AutoGPT AutoPilot with full access to "
"platform tools (agent management, workspace files, web fetch, "
"block execution, and more). Enables sub-agent patterns and "
"scheduled autopilot execution."
),
categories={BlockCategory.AI, BlockCategory.AGENT},
input_schema=AutoPilotBlock.Input,
output_schema=AutoPilotBlock.Output,
test_input={
"prompt": "List my agents",
"system_context": "",
"session_id": "",
"max_recursion_depth": 3,
},
test_output=[
("response", "You have 2 agents: Agent A and Agent B."),
("tool_calls", []),
(
"conversation_history",
'[{"role": "user", "content": "List my agents"}]',
),
("session_id", "test-session-id"),
(
"token_usage",
{
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
),
],
test_mock={
"create_session": lambda *args, **kwargs: "test-session-id",
"execute_copilot": lambda *args, **kwargs: (
"You have 2 agents: Agent A and Agent B.",
[],
'[{"role": "user", "content": "List my agents"}]',
"test-session-id",
{
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150,
},
),
},
)
async def create_session(self, user_id: str) -> str:
"""Create a new chat session and return its ID (mockable for tests)."""
from backend.copilot.model import create_chat_session
session = await create_chat_session(user_id)
return session.session_id
async def execute_copilot(
self,
prompt: str,
system_context: str,
session_id: str,
max_recursion_depth: int,
user_id: str,
) -> tuple[str, list[ToolCallEntry], str, str, TokenUsage]:
"""Invoke the copilot and collect all stream results.
Delegates to :func:`collect_copilot_response` — the shared helper that
consumes ``stream_chat_completion_sdk`` without wrapping it in an
``asyncio.timeout`` (the SDK manages its own heartbeat-based timeouts).
Args:
prompt: The user task/instruction.
system_context: Optional context prepended to the prompt.
session_id: Chat session to use.
max_recursion_depth: Maximum allowed recursion nesting.
user_id: Authenticated user ID.
Returns:
A tuple of (response_text, tool_calls, history_json, session_id, usage).
"""
from backend.copilot.sdk.collect import collect_copilot_response
tokens = _check_recursion(max_recursion_depth)
try:
effective_prompt = prompt
if system_context:
effective_prompt = f"[System Context: {system_context}]\n\n{prompt}"
result = await collect_copilot_response(
session_id=session_id,
message=effective_prompt,
user_id=user_id,
)
# Build a lightweight conversation summary from streamed data.
turn_messages: list[dict[str, Any]] = [
{"role": "user", "content": effective_prompt},
]
if result.tool_calls:
turn_messages.append(
{
"role": "assistant",
"content": result.response_text,
"tool_calls": result.tool_calls,
}
)
else:
turn_messages.append(
{"role": "assistant", "content": result.response_text}
)
history_json = json.dumps(turn_messages, default=str)
tool_calls: list[ToolCallEntry] = [
{
"tool_call_id": tc["tool_call_id"],
"tool_name": tc["tool_name"],
"input": tc["input"],
"output": tc["output"],
"success": tc["success"],
}
for tc in result.tool_calls
]
usage: TokenUsage = {
"prompt_tokens": result.prompt_tokens,
"completion_tokens": result.completion_tokens,
"total_tokens": result.total_tokens,
}
return (
result.response_text,
tool_calls,
history_json,
session_id,
usage,
)
finally:
_reset_recursion(tokens)
async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
"""Validate inputs, invoke the autopilot, and yield structured outputs.
Yields session_id even on failure so callers can inspect/resume the session.
"""
if not input_data.prompt.strip():
yield "error", "Prompt cannot be empty."
return
if not execution_context.user_id:
yield "error", "Cannot run autopilot without an authenticated user."
return
if input_data.max_recursion_depth < 1:
yield "error", "max_recursion_depth must be at least 1."
return
# Create session eagerly so the user always gets the session_id,
# even if the downstream stream fails (avoids orphaned sessions).
sid = input_data.session_id
if not sid:
sid = await self.create_session(execution_context.user_id)
# NOTE: No asyncio.timeout() here — the SDK manages its own
# heartbeat-based timeouts internally. Wrapping with asyncio.timeout
# would cancel the task mid-flight, corrupting the SDK's internal
# anyio memory stream (see service.py CRITICAL comment).
try:
response, tool_calls, history, _, usage = await self.execute_copilot(
prompt=input_data.prompt,
system_context=input_data.system_context,
session_id=sid,
max_recursion_depth=input_data.max_recursion_depth,
user_id=execution_context.user_id,
)
yield "response", response
yield "tool_calls", tool_calls
yield "conversation_history", history
yield "session_id", sid
yield "token_usage", usage
except asyncio.CancelledError:
yield "session_id", sid
yield "error", "AutoPilot execution was cancelled."
raise
except Exception as exc:
yield "session_id", sid
yield "error", str(exc)
# ---------------------------------------------------------------------------
# Helpers placed after the block class for top-down readability.
# ---------------------------------------------------------------------------
# Task-scoped recursion depth counter & chain-wide limit.
# contextvars are scoped to the current asyncio task, so concurrent
# graph executions each get independent counters.
_autopilot_recursion_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
"_autopilot_recursion_depth", default=0
)
_autopilot_recursion_limit: contextvars.ContextVar[int | None] = contextvars.ContextVar(
"_autopilot_recursion_limit", default=None
)
def _check_recursion(
max_depth: int,
) -> tuple[contextvars.Token[int], contextvars.Token[int | None]]:
"""Check and increment recursion depth.
Returns ContextVar tokens that must be passed to ``_reset_recursion``
when the caller exits to restore the previous depth.
Raises:
RuntimeError: If the current depth already meets or exceeds the limit.
"""
current = _autopilot_recursion_depth.get()
inherited = _autopilot_recursion_limit.get()
limit = max_depth if inherited is None else min(inherited, max_depth)
if current >= limit:
raise RuntimeError(
f"AutoPilot recursion depth limit reached ({limit}). "
"The autopilot has called itself too many times."
)
return (
_autopilot_recursion_depth.set(current + 1),
_autopilot_recursion_limit.set(limit),
)
def _reset_recursion(
tokens: tuple[contextvars.Token[int], contextvars.Token[int | None]],
) -> None:
"""Restore recursion depth and limit to their previous values."""
_autopilot_recursion_depth.reset(tokens[0])
_autopilot_recursion_limit.reset(tokens[1])

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
             output_schema=PrintToConsoleBlock.Output,
             test_input={"text": "Hello, World!"},
             is_sensitive_action=True,
-            disabled=True,
+            disabled=True,  # Disabled per Nick Tindle's request (OPEN-3000)
             test_output=[
                 ("output", "Hello, World!"),
                 ("status", "printed"),

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
             start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
             # Execute the code
-            execution = await sandbox.run_code(  # type: ignore[attr-defined]
+            execution = await sandbox.run_code(
                 code,
                 language=language.value,
                 on_error=lambda e: sandbox.kill(),  # Kill the sandbox on error

@@ -472,7 +472,7 @@ class AddToListBlock(Block):
     async def run(self, input_data: Input, **kwargs) -> BlockOutput:
         entries_added = input_data.entries.copy()
-        if input_data.entry is not None:
+        if input_data.entry:
             entries_added.append(input_data.entry)
         updated_list = input_data.list.copy()
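Reviewer note: the two conditions in this hunk differ for falsy but valid entries such as `0`, `""`, `False`, or an empty list. Below is a minimal sketch of the difference, using two hypothetical helpers that are not part of `AddToListBlock`.

```python
from typing import Any


def append_if_truthy(entries: list[Any], entry: Any) -> list[Any]:
    """Behaves like `if input_data.entry:` and skips every falsy value."""
    out = list(entries)
    if entry:
        out.append(entry)
    return out


def append_if_not_none(entries: list[Any], entry: Any) -> list[Any]:
    """Behaves like `if input_data.entry is not None:` and skips only None."""
    out = list(entries)
    if entry is not None:
        out.append(entry)
    return out
```

Appending `0` to `[1]` leaves the list unchanged under the truthiness check but yields `[1, 0]` under the `is not None` check.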

@@ -21,7 +21,6 @@ from backend.data.model import (
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
from backend.util.request import resolve_and_check_blocked
TEST_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -97,11 +96,8 @@ class SendEmailBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Email sent successfully")],
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
is_sensitive_action=True,
)
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}
@staticmethod
def send_email(
config: SMTPConfig,
@@ -132,17 +128,6 @@ class SendEmailBlock(Block):
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
) -> BlockOutput:
try:
# --- SSRF Protection ---
smtp_port = input_data.config.smtp_port
if smtp_port not in self.ALLOWED_SMTP_PORTS:
yield "error", (
f"SMTP port {smtp_port} is not allowed. "
f"Allowed ports: {sorted(self.ALLOWED_SMTP_PORTS)}"
)
return
await resolve_and_check_blocked(input_data.config.smtp_server)
status = self.send_email(
config=input_data.config,
to_email=input_data.to_email,
@@ -194,19 +179,7 @@ class SendEmailBlock(Block):
"was rejected by the server. "
"Please verify your account is authorized to send emails."
)
except smtplib.SMTPConnectError:
yield "error", (
f"Cannot connect to SMTP server '{input_data.config.smtp_server}' "
f"on port {input_data.config.smtp_port}."
)
except smtplib.SMTPServerDisconnected:
yield "error", (
f"SMTP server '{input_data.config.smtp_server}' "
"disconnected unexpectedly."
)
except smtplib.SMTPDataError as e:
yield "error", f"Email data rejected by server: {str(e)}"
except ValueError as e:
yield "error", str(e)
except Exception as e:
raise e
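
Review note: the SSRF guard in this hunk has two parts: a port allow-list and a check on the resolved server address. The sketch below illustrates the idea with the standard library only. `check_smtp_port` and `is_internal_address` are hypothetical helpers; in particular, `is_internal_address` only approximates what a helper like `resolve_and_check_blocked` might test, since its implementation is not shown in this diff.

```python
import ipaddress
from typing import Optional

# Same port set as the ALLOWED_SMTP_PORTS constant in the hunk above.
ALLOWED_SMTP_PORTS = {25, 465, 587, 2525}


def check_smtp_port(port: int) -> Optional[str]:
    """Return an error message for a disallowed port, or None if it is fine."""
    if port not in ALLOWED_SMTP_PORTS:
        return (
            f"SMTP port {port} is not allowed. "
            f"Allowed ports: {sorted(ALLOWED_SMTP_PORTS)}"
        )
    return None


def is_internal_address(ip: str) -> bool:
    """Assumed policy: treat private, loopback, and link-local IPs as blocked."""
    addr = ipaddress.ip_address(ip)
    return addr.is_private or addr.is_loopback or addr.is_link_local
```

A real check would also need to resolve the hostname first and test every address it resolves to; that step is omitted here to keep the sketch free of network calls.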

View File

@@ -34,29 +34,17 @@ TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
"title": TEST_CREDENTIALS.type,
}
class ImageEditorModel(str, Enum):
FLUX_KONTEXT_PRO = "Flux Kontext Pro"
FLUX_KONTEXT_MAX = "Flux Kontext Max"
NANO_BANANA_PRO = "Nano Banana Pro"
NANO_BANANA_2 = "Nano Banana 2"
class FluxKontextModelName(str, Enum):
PRO = "Flux Kontext Pro"
MAX = "Flux Kontext Max"
@property
def api_name(self) -> str:
_map = {
"FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
"FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
"NANO_BANANA_PRO": "google/nano-banana-pro",
"NANO_BANANA_2": "google/nano-banana-2",
}
return _map[self.name]
# Keep old name as alias for backwards compatibility
FluxKontextModelName = ImageEditorModel
return f"black-forest-labs/flux-kontext-{self.name.lower()}"
class AspectRatio(str, Enum):
@@ -81,7 +69,7 @@ class AIImageEditorBlock(Block):
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="Replicate API key with permissions for Flux Kontext and Nano Banana models",
description="Replicate API key with permissions for Flux Kontext models",
)
prompt: str = SchemaField(
description="Text instruction describing the desired edit",
@@ -99,14 +87,14 @@ class AIImageEditorBlock(Block):
advanced=False,
)
seed: Optional[int] = SchemaField(
description="Random seed. Set for reproducible generation (Flux Kontext only; ignored by Nano Banana models)",
description="Random seed. Set for reproducible generation",
default=None,
title="Seed",
advanced=True,
)
model: ImageEditorModel = SchemaField(
model: FluxKontextModelName = SchemaField(
description="Model variant to use",
default=ImageEditorModel.NANO_BANANA_2,
default=FluxKontextModelName.PRO,
title="Model",
)
@@ -119,7 +107,7 @@ class AIImageEditorBlock(Block):
super().__init__(
id="3fd9c73d-4370-4925-a1ff-1b86b99fabfa",
description=(
"Edit images using Flux Kontext or Google Nano Banana models. Provide a prompt "
"Edit images using BlackForest Labs' Flux Kontext models. Provide a prompt "
"and optional reference image to generate a modified image."
),
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
@@ -130,7 +118,7 @@ class AIImageEditorBlock(Block):
"input_image": "data:image/png;base64,MQ==",
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
"seed": None,
"model": ImageEditorModel.NANO_BANANA_2,
"model": FluxKontextModelName.PRO,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -139,9 +127,7 @@ class AIImageEditorBlock(Block):
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
),
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -156,7 +142,7 @@ class AIImageEditorBlock(Block):
) -> BlockOutput:
result = await self.run_model(
api_key=credentials.api_key,
model=input_data.model,
model_name=input_data.model.api_name,
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
@@ -183,7 +169,7 @@ class AIImageEditorBlock(Block):
async def run_model(
self,
api_key: SecretStr,
model: ImageEditorModel,
model_name: str,
prompt: str,
input_image_b64: Optional[str],
aspect_ratio: str,
@@ -192,29 +178,12 @@ class AIImageEditorBlock(Block):
graph_exec_id: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
model_name = model.api_name
is_nano_banana = model in (
ImageEditorModel.NANO_BANANA_PRO,
ImageEditorModel.NANO_BANANA_2,
)
if is_nano_banana:
input_params: dict = {
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"output_format": "jpg",
"safety_filter_level": "block_only_high",
}
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
if input_image_b64:
input_params["image_input"] = [input_image_b64]
else:
input_params = {
"prompt": prompt,
"input_image": input_image_b64,
"aspect_ratio": aspect_ratio,
**({"seed": seed} if seed is not None else {}),
}
input_params = {
"prompt": prompt,
"input_image": input_image_b64,
"aspect_ratio": aspect_ratio,
**({"seed": seed} if seed is not None else {}),
}
try:
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
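
Review note: the multi-model variant in this file differs from the Flux-only variant in two ways: an explicit enum-name-to-API-name map, and a separate input payload for the Nano Banana models. The sketch below restates both as standalone code, based only on the lines in this diff. `build_input_params` is a hypothetical free function; in the diff this logic lives inside `run_model`.

```python
from enum import Enum
from typing import Optional


class ImageEditorModel(str, Enum):
    FLUX_KONTEXT_PRO = "Flux Kontext Pro"
    FLUX_KONTEXT_MAX = "Flux Kontext Max"
    NANO_BANANA_PRO = "Nano Banana Pro"
    NANO_BANANA_2 = "Nano Banana 2"

    @property
    def api_name(self) -> str:
        # Explicit map: the name cannot be derived from one prefix anymore.
        return {
            "FLUX_KONTEXT_PRO": "black-forest-labs/flux-kontext-pro",
            "FLUX_KONTEXT_MAX": "black-forest-labs/flux-kontext-max",
            "NANO_BANANA_PRO": "google/nano-banana-pro",
            "NANO_BANANA_2": "google/nano-banana-2",
        }[self.name]


def build_input_params(
    model: ImageEditorModel,
    prompt: str,
    input_image_b64: Optional[str],
    aspect_ratio: str,
    seed: Optional[int],
) -> dict:
    """Build the request payload for the selected model family."""
    if model in (ImageEditorModel.NANO_BANANA_PRO, ImageEditorModel.NANO_BANANA_2):
        params: dict = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "output_format": "jpg",
            "safety_filter_level": "block_only_high",
        }
        # Nano Banana takes a list under "image_input"; the seed is ignored.
        if input_image_b64:
            params["image_input"] = [input_image_b64]
        return params
    # Flux Kontext takes a single "input_image" and an optional seed.
    return {
        "prompt": prompt,
        "input_image": input_image_b64,
        "aspect_ratio": aspect_ratio,
        **({"seed": seed} if seed is not None else {}),
    }
```

The Flux-only variant can derive `api_name` from the member name with an f-string because both members share the `black-forest-labs/flux-kontext-` prefix; that shortcut stops working once a second vendor is added.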

View File

@@ -1,3 +0,0 @@
def github_repo_path(repo_url: str) -> str:
"""Extract 'owner/repo' from a GitHub repository URL."""
return repo_url.replace("https://github.com/", "")
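
Review note: `github_repo_path` is a plain string replacement, so trailing slashes or a `.git` suffix pass through unchanged. The sketch below repeats the helper as written and adds a stricter variant for comparison. `github_repo_path_strict` is a hypothetical alternative and not part of this diff.

```python
from urllib.parse import urlparse


def github_repo_path(repo_url: str) -> str:
    """Extract 'owner/repo' from a GitHub repository URL (as in the diff)."""
    return repo_url.replace("https://github.com/", "")


def github_repo_path_strict(repo_url: str) -> str:
    """Hypothetical variant: parse the URL and normalise the path."""
    parsed = urlparse(repo_url)
    parts = [p for p in parsed.path.split("/") if p]
    if parsed.netloc != "github.com" or len(parts) < 2:
        raise ValueError(f"Not a GitHub repository URL: {repo_url!r}")
    owner, repo = parts[0], parts[1]
    if repo.endswith(".git"):
        repo = repo[: -len(".git")]
    return f"{owner}/{repo}"


simple = github_repo_path("https://github.com/owner/repo")
messy = github_repo_path("https://github.com/owner/repo.git/")
strict = github_repo_path_strict("https://github.com/owner/repo.git/")
```

Whether the stricter behaviour is wanted depends on how callers build `repo_url`; for URLs that already match the placeholder format, both functions return the same value.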

View File

@@ -1,408 +0,0 @@
import asyncio
from enum import StrEnum
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import parse_data_uri, resolve_media_content
from backend.util.type import MediaFileType
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListCommitsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to list commits from",
default="main",
)
per_page: int = SchemaField(
description="Number of commits to return (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class CommitItem(TypedDict):
sha: str
message: str
author: str
date: str
url: str
commit: CommitItem = SchemaField(
title="Commit", description="A commit with its details"
)
commits: list[CommitItem] = SchemaField(
description="List of commits with their details"
)
error: str = SchemaField(description="Error message if listing commits failed")
def __init__(self):
super().__init__(
id="8b13f579-d8b6-4dc2-a140-f770428805de",
description="This block lists commits on a branch in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommitsBlock.Input,
output_schema=GithubListCommitsBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"commits",
[
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
],
),
(
"commit",
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
},
),
],
test_mock={
"list_commits": lambda *args, **kwargs: [
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
]
},
)
@staticmethod
async def list_commits(
credentials: GithubCredentials,
repo_url: str,
branch: str,
per_page: int,
page: int,
) -> list[Output.CommitItem]:
api = get_api(credentials)
commits_url = repo_url + "/commits"
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
response = await api.get(commits_url, params=params)
data = response.json()
repo_path = github_repo_path(repo_url)
return [
GithubListCommitsBlock.Output.CommitItem(
sha=c["sha"],
message=c["commit"]["message"],
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
date=(c["commit"].get("author") or {}).get("date", ""),
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
)
for c in data
]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
commits = await self.list_commits(
credentials,
input_data.repo_url,
input_data.branch,
input_data.per_page,
input_data.page,
)
yield "commits", commits
for commit in commits:
yield "commit", commit
except Exception as e:
yield "error", str(e)
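
Review note: `list_commits` guards against a missing or null `author` object with the `(... .get("author") or {})` idiom. A standalone sketch of that mapping follows. `to_commit_item` is a hypothetical free function, and the sample payload is invented for illustration.

```python
def to_commit_item(c: dict, repo_path: str) -> dict:
    """Map one commit object from the API response to the block's CommitItem shape."""
    # "author" may be absent or null; fall back to an empty dict either way.
    author = c["commit"].get("author") or {}
    return {
        "sha": c["sha"],
        "message": c["commit"]["message"],
        "author": author.get("name", "Unknown"),
        "date": author.get("date", ""),
        "url": f"https://github.com/{repo_path}/commit/{c['sha']}",
    }


# Invented sample: a commit whose author field is null.
sample = {"sha": "abc123", "commit": {"message": "Initial commit", "author": None}}
item = to_commit_item(sample, "owner/repo")
```

Using `or {}` instead of a `.get("author", {})` default matters because a key that is present with a `None` value would bypass the default.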
class FileOperation(StrEnum):
"""File operations for GithubMultiFileCommitBlock.
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
between creation and update — the blob is placed at the given path regardless
of whether a file already exists there).
DELETE removes a file from the tree.
"""
UPSERT = "upsert"
DELETE = "delete"
class FileOperationInput(TypedDict):
path: str
# MediaFileType is a str NewType — no runtime breakage for existing callers.
content: MediaFileType
operation: FileOperation
class GithubMultiFileCommitBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch to commit to",
placeholder="feature-branch",
)
commit_message: str = SchemaField(
description="Commit message",
placeholder="Add new feature",
)
files: list[FileOperationInput] = SchemaField(
description=(
"List of file operations. Each item has: "
"'path' (file path), 'content' (file content, ignored for delete), "
"'operation' (upsert/delete)"
),
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the new commit")
url: str = SchemaField(description="URL of the new commit")
error: str = SchemaField(description="Error message if the commit failed")
def __init__(self):
super().__init__(
id="389eee51-a95e-4230-9bed-92167a327802",
description=(
"This block creates a single commit with multiple file "
"upsert/delete operations using the Git Trees API."
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMultiFileCommitBlock.Input,
output_schema=GithubMultiFileCommitBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "Add files",
"files": [
{
"path": "src/new.py",
"content": "print('hello')",
"operation": "upsert",
},
{
"path": "src/old.py",
"content": "",
"operation": "delete",
},
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "newcommitsha"),
("url", "https://github.com/owner/repo/commit/newcommitsha"),
],
test_mock={
"multi_file_commit": lambda *args, **kwargs: (
"newcommitsha",
"https://github.com/owner/repo/commit/newcommitsha",
)
},
)
@staticmethod
async def multi_file_commit(
credentials: GithubCredentials,
repo_url: str,
branch: str,
commit_message: str,
files: list[FileOperationInput],
) -> tuple[str, str]:
api = get_api(credentials)
safe_branch = quote(branch, safe="")
# 1. Get the latest commit SHA for the branch
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
response = await api.get(ref_url)
ref_data = response.json()
latest_commit_sha = ref_data["object"]["sha"]
# 2. Get the tree SHA of the latest commit
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
response = await api.get(commit_url)
commit_data = response.json()
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": encoding},
)
return blob_response.json()["sha"]
tree_entries: list[dict] = []
upsert_files = []
for file_op in files:
path = file_op["path"]
operation = FileOperation(file_op.get("operation", "upsert"))
if operation == FileOperation.DELETE:
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": None, # null SHA = delete
}
)
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently. Data URIs (from store_media_file)
# are sent as base64 blobs to preserve binary content.
if upsert_files:
async def _make_blob(content: str) -> str:
parsed = parse_data_uri(content)
if parsed is not None:
_, b64_payload = parsed
return await _create_blob(b64_payload, encoding="base64")
return await _create_blob(content)
blob_shas = await asyncio.gather(
*[_make_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": blob_sha,
}
)
# 4. Create a new tree
tree_url = repo_url + "/git/trees"
tree_response = await api.post(
tree_url,
json={"base_tree": base_tree_sha, "tree": tree_entries},
)
new_tree_sha = tree_response.json()["sha"]
# 5. Create a new commit
new_commit_url = repo_url + "/git/commits"
commit_response = await api.post(
new_commit_url,
json={
"message": commit_message,
"tree": new_tree_sha,
"parents": [latest_commit_sha],
},
)
new_commit_sha = commit_response.json()["sha"]
# 6. Update the branch reference
try:
await api.patch(
ref_url,
json={"sha": new_commit_sha},
)
except Exception as e:
raise RuntimeError(
f"Commit {new_commit_sha} was created but failed to update "
f"ref heads/{branch}: {e}. "
f"You can recover by manually updating the branch to {new_commit_sha}."
) from e
repo_path = github_repo_path(repo_url)
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
return new_commit_sha, commit_web_url
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
# Resolve media references (workspace://, data:, URLs) to data
# URIs so _make_blob can send binary content correctly.
resolved_files: list[FileOperationInput] = []
for file_op in input_data.files:
content = file_op.get("content", "")
operation = FileOperation(file_op.get("operation", "upsert"))
if operation != FileOperation.DELETE:
content = await resolve_media_content(
MediaFileType(content),
execution_context,
return_format="for_external_api",
)
resolved_files.append(
FileOperationInput(
path=file_op["path"],
content=MediaFileType(content),
operation=operation,
)
)
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
resolved_files,
)
yield "sha", sha
yield "url", url
except Exception as e:
yield "error", str(e)
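
Review note: `multi_file_commit` follows six steps: read the branch ref, read the base tree, create blobs, create a tree, create a commit, and move the ref. The tree-building step is the part that encodes upsert versus delete, and it can be sketched without any network calls. `build_tree_entries` is a hypothetical helper, and unlike the code in the diff it emits entries in input order rather than deletes first.

```python
from typing import Optional


def build_tree_entries(files: list, blob_shas: dict) -> list:
    """Turn file operations into tree entries for a 'create tree' request.

    files:     dicts with 'path' and 'operation' ('upsert' or 'delete')
    blob_shas: path -> SHA of an already created blob (upserts only)
    """
    entries = []
    for file_op in files:
        path = file_op["path"]
        operation = file_op.get("operation", "upsert")
        # A null SHA removes the path from the tree; otherwise point at the blob.
        sha: Optional[str] = None if operation == "delete" else blob_shas[path]
        entries.append({"path": path, "mode": "100644", "type": "blob", "sha": sha})
    return entries


entries = build_tree_entries(
    [
        {"path": "src/new.py", "operation": "upsert"},
        {"path": "src/old.py", "operation": "delete"},
    ],
    {"src/new.py": "blobsha1"},
)
```

Because the new tree is created with `base_tree` set, only the listed paths change; every other file in the branch is carried over unchanged.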

View File

@@ -1,5 +1,4 @@
import re
from typing import Literal
from typing_extensions import TypedDict
@@ -21,8 +20,6 @@ from ._auth import (
GithubCredentialsInput,
)
MergeMethod = Literal["merge", "squash", "rebase"]
class GithubListPullRequestsBlock(Block):
class Input(BlockSchemaInput):
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
yield "reviewer", reviewer
class GithubMergePullRequestBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
pr_url: str = SchemaField(
description="URL of the GitHub pull request",
placeholder="https://github.com/owner/repo/pull/1",
)
merge_method: MergeMethod = SchemaField(
description="Merge method to use: merge, squash, or rebase",
default="merge",
)
commit_title: str = SchemaField(
description="Title for the merge commit (optional, used for merge and squash)",
default="",
)
commit_message: str = SchemaField(
description="Message for the merge commit (optional, used for merge and squash)",
default="",
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the merge commit")
merged: bool = SchemaField(description="Whether the PR was merged")
message: str = SchemaField(description="Merge status message")
error: str = SchemaField(description="Error message if the merge failed")
def __init__(self):
super().__init__(
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
description="This block merges a pull request using merge, squash, or rebase.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMergePullRequestBlock.Input,
output_schema=GithubMergePullRequestBlock.Output,
test_input={
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "abc123"),
("merged", True),
("message", "Pull Request successfully merged"),
],
test_mock={
"merge_pr": lambda *args, **kwargs: (
"abc123",
True,
"Pull Request successfully merged",
)
},
is_sensitive_action=True,
)
@staticmethod
async def merge_pr(
credentials: GithubCredentials,
pr_url: str,
merge_method: MergeMethod,
commit_title: str,
commit_message: str,
) -> tuple[str, bool, str]:
api = get_api(credentials)
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
data: dict[str, str] = {"merge_method": merge_method}
if commit_title:
data["commit_title"] = commit_title
if commit_message:
data["commit_message"] = commit_message
response = await api.put(merge_url, json=data)
result = response.json()
return result["sha"], result["merged"], result["message"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, merged, message = await self.merge_pr(
credentials,
input_data.pr_url,
input_data.merge_method,
input_data.commit_title,
input_data.commit_message,
)
yield "sha", sha
yield "merged", merged
yield "message", message
except Exception as e:
yield "error", str(e)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
scheme, base_url, pr_number = match.groups()
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"
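
Review note: the two versions of `prepare_pr_api_url` in this hunk return different strings for the same input. One captures the scheme and re-emits it (defaulting to `https`); the other drops the scheme entirely. The sketch below places both side by side under hypothetical names so the difference is visible. How the downstream API client treats a scheme-less URL is not shown in this diff.

```python
import re


def prepare_with_scheme(pr_url: str, path: str) -> str:
    """Variant that captures the scheme and defaults to https."""
    pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
    match = re.match(pattern, pr_url)
    if not match:
        return pr_url
    scheme, base_url, pr_number = match.groups()
    return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"


def prepare_without_scheme(pr_url: str, path: str) -> str:
    """Variant that discards the scheme."""
    pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
    match = re.match(pattern, pr_url)
    if not match:
        return pr_url
    base_url, pr_number = match.groups()
    return f"{base_url}/pulls/{pr_number}/{path}"


url = "https://github.com/owner/repo/pull/1"
with_scheme = prepare_with_scheme(url, "merge")
without_scheme = prepare_without_scheme(url, "merge")
```

Both variants return the input unchanged when it does not look like a pull request URL, so callers cannot tell a rewritten URL from a rejected one without checking the result.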

View File

@@ -1,3 +1,5 @@
import base64
from typing_extensions import TypedDict
from backend.blocks._base import (
@@ -17,7 +19,6 @@ from ._auth import (
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListTagsBlock(Block):
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
tags_url = repo_url + "/tags"
response = await api.get(tags_url)
data = response.json()
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
tags: list[GithubListTagsBlock.Output.TagItem] = [
{
"name": tag["name"],
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
yield "tag", tag
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(branches_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
branches = await self.list_branches(
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
class GithubListDiscussionsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
) -> list[Output.DiscussionItem]:
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
query = """
query($owner: String!, $repo: String!, $num: Int!) {
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
yield "release", release
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
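
Review note: `GithubReadFileBlock.run` decodes the base64 payload as UTF-8 without a fallback, so a binary file would raise `UnicodeDecodeError`. The sketch below shows the decoding step in isolation. `decode_file_content` and its `errors` parameter are hypothetical, and the assumption that the payload may contain line breaks reflects typical API responses rather than anything shown in this diff.

```python
import base64


def decode_file_content(raw_b64: str, errors: str = "strict") -> str:
    """Decode a base64 payload to text.

    b64decode() ignores characters outside the base64 alphabet by default,
    so embedded newlines in the payload do not need to be stripped first.
    """
    return base64.b64decode(raw_b64).decode("utf-8", errors=errors)


text = decode_file_content("RmlsZSBj\nb250ZW50\n")
# A single 0xFF byte is not valid UTF-8; "replace" substitutes U+FFFD.
lossy = decode_file_content("/w==", errors="replace")
```

With the default `errors="strict"` the function behaves like the code in the diff; a caller that must handle arbitrary files could rely on `raw_content` instead of `text_content`.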
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to master)",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubCreateRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
def __init__(self):
super().__init__(
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
description="This block lists all users who have starred a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListStargazersBlock.Input,
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer
class GithubGetRepositoryInfoBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
name: str = SchemaField(description="Repository name")
full_name: str = SchemaField(description="Full repository name (owner/repo)")
description: str = SchemaField(description="Repository description")
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
private: bool = SchemaField(description="Whether the repository is private")
html_url: str = SchemaField(description="Web URL of the repository")
clone_url: str = SchemaField(description="Git clone URL")
stars: int = SchemaField(description="Number of stars")
forks: int = SchemaField(description="Number of forks")
open_issues: int = SchemaField(description="Number of open issues")
error: str = SchemaField(
description="Error message if fetching repo info failed"
)
def __init__(self):
super().__init__(
id="59d4f241-968a-4040-95da-348ac5c5ce27",
description="This block retrieves metadata about a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryInfoBlock.Input,
output_schema=GithubGetRepositoryInfoBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("name", "repo"),
("full_name", "owner/repo"),
("description", "A test repo"),
("default_branch", "main"),
("private", False),
("html_url", "https://github.com/owner/repo"),
("clone_url", "https://github.com/owner/repo.git"),
("stars", 42),
("forks", 5),
("open_issues", 3),
],
test_mock={
"get_repo_info": lambda *args, **kwargs: {
"name": "repo",
"full_name": "owner/repo",
"description": "A test repo",
"default_branch": "main",
"private": False,
"html_url": "https://github.com/owner/repo",
"clone_url": "https://github.com/owner/repo.git",
"stargazers_count": 42,
"forks_count": 5,
"open_issues_count": 3,
}
},
)
@staticmethod
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
api = get_api(credentials)
response = await api.get(repo_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.get_repo_info(credentials, input_data.repo_url)
yield "name", data["name"]
yield "full_name", data["full_name"]
yield "description", data.get("description", "") or ""
yield "default_branch", data["default_branch"]
yield "private", data["private"]
yield "html_url", data["html_url"]
yield "clone_url", data["clone_url"]
yield "stars", data["stargazers_count"]
yield "forks", data["forks_count"]
yield "open_issues", data["open_issues_count"]
except Exception as e:
yield "error", str(e)
class GithubForkRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to fork",
placeholder="https://github.com/owner/repo",
)
organization: str = SchemaField(
description="Organization to fork into (leave empty to fork to your account)",
default="",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the forked repository")
clone_url: str = SchemaField(description="Git clone URL of the fork")
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
error: str = SchemaField(description="Error message if the fork failed")
def __init__(self):
super().__init__(
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
description="This block forks a GitHub repository to your account or an organization.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubForkRepositoryBlock.Input,
output_schema=GithubForkRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"organization": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/myuser/repo"),
("clone_url", "https://github.com/myuser/repo.git"),
("full_name", "myuser/repo"),
],
test_mock={
"fork_repo": lambda *args, **kwargs: (
"https://github.com/myuser/repo",
"https://github.com/myuser/repo.git",
"myuser/repo",
)
},
)
@staticmethod
async def fork_repo(
credentials: GithubCredentials,
repo_url: str,
organization: str,
) -> tuple[str, str, str]:
api = get_api(credentials)
forks_url = repo_url + "/forks"
data: dict[str, str] = {}
if organization:
data["organization"] = organization
response = await api.post(forks_url, json=data)
result = response.json()
return result["html_url"], result["clone_url"], result["full_name"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, clone_url, full_name = await self.fork_repo(
credentials,
input_data.repo_url,
input_data.organization,
)
yield "url", url
yield "clone_url", clone_url
yield "full_name", full_name
except Exception as e:
yield "error", str(e)
class GithubStarRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to star",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the star operation")
error: str = SchemaField(description="Error message if starring failed")
def __init__(self):
super().__init__(
id="bd700764-53e3-44dd-a969-d1854088458f",
description="This block stars a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubStarRepositoryBlock.Input,
output_schema=GithubStarRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Repository starred successfully")],
test_mock={
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
},
)
@staticmethod
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
api = get_api(credentials, convert_urls=False)
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
await api.put(
f"https://api.github.com/user/starred/{owner}/{repo}",
headers={"Content-Length": "0"},
)
return "Repository starred successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.star_repo(credentials, input_data.repo_url)
yield "status", status
except Exception as e:
yield "error", str(e)
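
Reviewer note: the create and update file blocks in this file build the same request body for the repository contents endpoint, differing only in the blob `sha` that an update sends. A minimal sketch of that shared step, assuming a hypothetical helper name `build_contents_payload` that does not exist in the codebase:

```python
import base64
from typing import Optional


def build_contents_payload(
    content: str,
    branch: str,
    commit_message: str,
    sha: Optional[str] = None,
) -> dict[str, str]:
    """Hypothetical helper: request body for creating or updating a file.

    Mirrors the dict built inline in create_file and update_file above.
    """
    payload = {
        "message": commit_message,
        # The file body is sent as base64 text, as in the blocks above.
        "content": base64.b64encode(content.encode()).decode(),
        "branch": branch,
    }
    # update_file also sends the sha it read from the existing file.
    if sha is not None:
        payload["sha"] = sha
    return payload


create_body = build_contents_payload("File content", "main", "Create test file")
update_body = build_contents_payload(
    "Updated content", "main", "Update test file", sha="abc123"
)
```

Factoring the body out this way would keep the two blocks from drifting apart; whether that is worth a shared helper is a judgment call for the maintainers.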


@@ -1,452 +0,0 @@
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
per_page: int = SchemaField(
description="Number of branches to return per page (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(
branches_url, params={"per_page": str(per_page), "page": str(page)}
)
data = response.json()
repo_path = github_repo_path(repo_url)
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = await self.list_branches(
credentials,
input_data.repo_url,
input_data.per_page,
input_data.page,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
except Exception as e:
yield "error", str(e)
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
is_sensitive_action=True,
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubCompareBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
base: str = SchemaField(
description="Base branch or commit SHA",
placeholder="main",
)
head: str = SchemaField(
description="Head branch or commit SHA to compare against base",
placeholder="feature-branch",
)
class Output(BlockSchemaOutput):
class FileChange(TypedDict):
filename: str
status: str
additions: int
deletions: int
patch: str
status: str = SchemaField(
description="Comparison status: ahead, behind, diverged, or identical"
)
ahead_by: int = SchemaField(
description="Number of commits head is ahead of base"
)
behind_by: int = SchemaField(
description="Number of commits head is behind base"
)
total_commits: int = SchemaField(
description="Total number of commits in the comparison"
)
diff: str = SchemaField(description="Unified diff of all file changes")
file: FileChange = SchemaField(
title="Changed File", description="A changed file with its diff"
)
files: list[FileChange] = SchemaField(
description="List of changed files with their diffs"
)
error: str = SchemaField(description="Error message if comparison failed")
def __init__(self):
super().__init__(
id="2e4faa8c-6086-4546-ba77-172d1d560186",
description="This block compares two branches or commits in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCompareBranchesBlock.Input,
output_schema=GithubCompareBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"base": "main",
"head": "feature",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("status", "ahead"),
("ahead_by", 2),
("behind_by", 0),
("total_commits", 2),
("diff", "+++ b/file.py\n+new line"),
(
"files",
[
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
),
(
"file",
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
},
),
],
test_mock={
"compare_branches": lambda *args, **kwargs: {
"status": "ahead",
"ahead_by": 2,
"behind_by": 0,
"total_commits": 2,
"files": [
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
}
},
)
@staticmethod
async def compare_branches(
credentials: GithubCredentials,
repo_url: str,
base: str,
head: str,
) -> dict:
api = get_api(credentials)
safe_base = quote(base, safe="")
safe_head = quote(head, safe="")
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
response = await api.get(compare_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.compare_branches(
credentials,
input_data.repo_url,
input_data.base,
input_data.head,
)
yield "status", data["status"]
yield "ahead_by", data["ahead_by"]
yield "behind_by", data["behind_by"]
yield "total_commits", data["total_commits"]
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
GithubCompareBranchesBlock.Output.FileChange(
filename=f["filename"],
status=f["status"],
additions=f["additions"],
deletions=f["deletions"],
patch=f.get("patch", ""),
)
for f in data.get("files", [])
]
# Build unified diff
diff_parts = []
for f in data.get("files", []):
patch = f.get("patch", "")
if patch:
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
yield "diff", "\n".join(diff_parts)
yield "files", files
for file in files:
yield "file", file
except Exception as e:
yield "error", str(e)
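
Reviewer note: two details of the compare block are easy to miss when reading the flattened diff. Ref names are percent-encoded with `safe=""`, so a branch such as `feature/login` stays a single path segment, and the `diff` output is assembled from the per-file `patch` fields, skipping files that have none. A small standalone sketch of both steps, assuming hypothetical function names `compare_url` and `build_diff` and an illustrative API base URL:

```python
from urllib.parse import quote


def compare_url(repo_url: str, base: str, head: str) -> str:
    """Hypothetical helper: build the compare URL as compare_branches does."""
    # safe="" also encodes "/", so "feature/login" becomes "feature%2Flogin".
    return f"{repo_url}/compare/{quote(base, safe='')}...{quote(head, safe='')}"


def build_diff(files: list[dict]) -> str:
    """Hypothetical helper: join per-file patches as the run method does."""
    parts = []
    for f in files:
        patch = f.get("patch", "")
        if patch:  # entries without a patch field are skipped
            parts.append(f"+++ b/{f['filename']}\n{patch}")
    return "\n".join(parts)


url = compare_url("https://api.github.com/repos/owner/repo", "main", "feature/login")
diff = build_diff(
    [
        {"filename": "file.py", "patch": "+new line"},
        {"filename": "image.png"},  # no patch field
    ]
)
```

Note that the assembled text is not a complete unified diff: it carries only a `+++ b/<filename>` header per file, without the matching `---` line.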


@@ -1,720 +0,0 @@
import base64
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
error: str = SchemaField(description="Error message if reading the file failed")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = (
repo_url
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
)
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", str(e)
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to main)",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = (
repo_url
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
)
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
except Exception as e:
yield "error", str(e)
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
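The update flow in `update_file` above first reads the file to get its current blob SHA, then sends that SHA back together with the base64-encoded content. A minimal sketch of the payload step only, with no network calls; `build_update_payload` is a hypothetical helper name used for illustration and is not part of the block:

```python
import base64


def build_update_payload(
    content: str, current_sha: str, branch: str, commit_message: str
) -> dict:
    """Build the JSON body for a contents PUT that replaces an existing file."""
    # The file body is sent as base64 text rather than raw bytes.
    content_base64 = base64.b64encode(content.encode()).decode()
    return {
        "message": commit_message,
        "content": content_base64,
        # SHA of the blob being replaced, taken from the earlier GET response.
        "sha": current_sha,
        "branch": branch,
    }


payload = build_update_payload("Updated content", "abc123", "main", "Update file")
```

The body built in `create_file` is the same minus the `sha` key, since there is no existing blob to replace.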
class GithubSearchCodeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
query: str = SchemaField(
description="Search query (GitHub code search syntax)",
placeholder="className language:python",
)
repo: str = SchemaField(
description="Restrict search to a repository (owner/repo format, optional)",
default="",
placeholder="owner/repo",
)
per_page: int = SchemaField(
description="Number of results to return (max 100)",
default=30,
ge=1,
le=100,
)
class Output(BlockSchemaOutput):
class SearchResult(TypedDict):
name: str
path: str
repository: str
url: str
score: float
result: SearchResult = SchemaField(
title="Result", description="A code search result"
)
results: list[SearchResult] = SchemaField(
description="List of code search results"
)
total_count: int = SchemaField(description="Total number of matching results")
error: str = SchemaField(description="Error message if search failed")
def __init__(self):
super().__init__(
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
description="This block searches for code in GitHub repositories.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSearchCodeBlock.Input,
output_schema=GithubSearchCodeBlock.Output,
test_input={
"query": "addClass",
"repo": "owner/repo",
"per_page": 30,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("total_count", 1),
(
"results",
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
),
(
"result",
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
},
),
],
test_mock={
"search_code": lambda *args, **kwargs: (
1,
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
)
},
)
@staticmethod
async def search_code(
credentials: GithubCredentials,
query: str,
repo: str,
per_page: int,
) -> tuple[int, list[Output.SearchResult]]:
api = get_api(credentials, convert_urls=False)
full_query = f"{query} repo:{repo}" if repo else query
params = {"q": full_query, "per_page": str(per_page)}
response = await api.get("https://api.github.com/search/code", params=params)
data = response.json()
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
GithubSearchCodeBlock.Output.SearchResult(
name=item["name"],
path=item["path"],
repository=item["repository"]["full_name"],
url=item["html_url"],
score=item["score"],
)
for item in data["items"]
]
return data["total_count"], results
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
total_count, results = await self.search_code(
credentials,
input_data.query,
input_data.repo,
input_data.per_page,
)
yield "total_count", total_count
yield "results", results
for result in results:
yield "result", result
except Exception as e:
yield "error", str(e)
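The query handling in `search_code` above appends a `repo:` qualifier only when a repository is given. A small sketch of that step in isolation; `build_search_params` is a hypothetical name, and the function only builds the parameters without calling the API:

```python
def build_search_params(query: str, repo: str, per_page: int) -> dict:
    """Compose query-string parameters for a code search request."""
    # An empty repo string means "search everywhere the token can see".
    full_query = f"{query} repo:{repo}" if repo else query
    # per_page is passed as a string, matching the block above.
    return {"q": full_query, "per_page": str(per_page)}


scoped = build_search_params("addClass", "owner/repo", 30)
unscoped = build_search_params("addClass", "", 30)
```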
class GithubGetRepositoryTreeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to get the tree from",
default="main",
)
recursive: bool = SchemaField(
description="Whether to recursively list the entire tree",
default=True,
)
class Output(BlockSchemaOutput):
class TreeEntry(TypedDict):
path: str
type: str
size: int
sha: str
entry: TreeEntry = SchemaField(
title="Tree Entry", description="A file or directory in the tree"
)
entries: list[TreeEntry] = SchemaField(
description="List of all files and directories in the tree"
)
truncated: bool = SchemaField(
description="Whether the tree was truncated due to size"
)
error: str = SchemaField(description="Error message if getting tree failed")
def __init__(self):
super().__init__(
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
description="This block lists the entire file tree of a GitHub repository recursively.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryTreeBlock.Input,
output_schema=GithubGetRepositoryTreeBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"recursive": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("truncated", False),
(
"entries",
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
),
(
"entry",
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
},
),
],
test_mock={
"get_tree": lambda *args, **kwargs: (
False,
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
)
},
)
@staticmethod
async def get_tree(
credentials: GithubCredentials,
repo_url: str,
branch: str,
recursive: bool,
) -> tuple[bool, list[Output.TreeEntry]]:
api = get_api(credentials)
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
params = {"recursive": "1"} if recursive else {}
response = await api.get(tree_url, params=params)
data = response.json()
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
GithubGetRepositoryTreeBlock.Output.TreeEntry(
path=item["path"],
type=item["type"],
size=item.get("size", 0),
sha=item["sha"],
)
for item in data["tree"]
]
return data.get("truncated", False), entries
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
truncated, entries = await self.get_tree(
credentials,
input_data.repo_url,
input_data.branch,
input_data.recursive,
)
yield "truncated", truncated
yield "entries", entries
for entry in entries:
yield "entry", entry
except Exception as e:
yield "error", str(e)
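The parsing in `get_tree` above defaults `size` to 0 because directory entries in a tree response carry no size. A self-contained sketch of that parsing step on a hand-written sample; `parse_tree` and the sample payload are illustrative assumptions, not code from the repository:

```python
from typing import Any, List, Tuple, TypedDict


class TreeEntry(TypedDict):
    path: str
    type: str
    size: int
    sha: str


def parse_tree(data: dict) -> Tuple[bool, List[TreeEntry]]:
    """Turn a decoded tree response into (truncated, entries)."""
    entries: List[TreeEntry] = [
        TreeEntry(
            path=item["path"],
            type=item["type"],
            # Directory ("tree") entries have no size field, so fall back to 0.
            size=item.get("size", 0),
            sha=item["sha"],
        )
        for item in data["tree"]
    ]
    return data.get("truncated", False), entries


sample: Any = {
    "tree": [
        {"path": "src", "type": "tree", "sha": "aaa111"},
        {"path": "src/main.py", "type": "blob", "size": 1234, "sha": "abc123"},
    ],
    "truncated": False,
}
truncated, entries = parse_tree(sample)
```

When `truncated` comes back true, the listing is incomplete, so a caller that needs every path would have to walk subtrees individually.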

@@ -1,125 +0,0 @@
import inspect
import pytest
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.data.execution import ExecutionContext
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
class TestPreparePrApiUrl:
def test_https_scheme_preserved(self):
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
assert result == "https://github.com/owner/repo/pulls/42/merge"
def test_http_scheme_preserved(self):
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
assert result == "http://github.com/owner/repo/pulls/1/files"
def test_no_scheme_defaults_to_https(self):
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
assert result == "https://github.com/owner/repo/pulls/5/merge"
def test_reviewers_path(self):
result = prepare_pr_api_url(
"https://github.com/owner/repo/pull/99", "requested_reviewers"
)
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
def test_invalid_url_returned_as_is(self):
url = "https://example.com/not-a-pr"
assert prepare_pr_api_url(url, "merge") == url
def test_empty_string(self):
assert prepare_pr_api_url("", "merge") == ""
# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).
def _mock_block(block, mocks: dict):
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
for name, mock_fn in mocks.items():
original = getattr(block, name)
if inspect.iscoroutinefunction(original):
async def async_mock(*args, _fn=mock_fn, **kwargs):
return _fn(*args, **kwargs)
setattr(block, name, async_mock)
else:
setattr(block, name, mock_fn)
def _raise(exc: Exception):
"""Helper that returns a callable which raises the given exception."""
def _raiser(*args, **kwargs):
raise exc
return _raiser
@pytest.mark.asyncio
async def test_merge_pr_error_path():
block = GithubMergePullRequestBlock()
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
input_data = {
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
block = GithubMultiFileCommitBlock()
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
input_data = {
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "test",
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(
input_data,
credentials=TEST_CREDENTIALS,
execution_context=ExecutionContext(),
):
pass
# ── FileOperation enum tests ──
class TestFileOperation:
def test_upsert_value(self):
assert FileOperation.UPSERT == "upsert"
def test_delete_value(self):
assert FileOperation.DELETE == "delete"
def test_invalid_value_raises(self):
with pytest.raises(ValueError):
FileOperation("create")
def test_invalid_value_raises_typo(self):
with pytest.raises(ValueError):
FileOperation("upser")
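The `_mock_block` helper in the test file above wraps synchronous mocks in a coroutine when the method being replaced is async, and binds the mock through a default argument so each wrapper keeps its own function. A runnable sketch of the same idea against a stand-in class; `FakeBlock` and its `merge_pr` method are made up for illustration:

```python
import asyncio
import inspect


class FakeBlock:
    @staticmethod
    async def merge_pr(pr_url: str) -> str:
        raise NotImplementedError("the real method would call the network")


def mock_block(block, mocks: dict) -> None:
    """Replace methods on a block instance, wrapping sync mocks as async."""
    for name, mock_fn in mocks.items():
        original = getattr(block, name)
        if inspect.iscoroutinefunction(original):
            # _fn=mock_fn captures the current mock at definition time,
            # which avoids the late-binding closure problem inside the loop.
            async def async_mock(*args, _fn=mock_fn, **kwargs):
                return _fn(*args, **kwargs)

            setattr(block, name, async_mock)
        else:
            setattr(block, name, mock_fn)


block = FakeBlock()
mock_block(block, {"merge_pr": lambda pr_url: f"merged {pr_url}"})
result = asyncio.run(block.merge_pr("https://github.com/owner/repo/pull/1"))
```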

@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
h.ignore_links = False
h.ignore_images = True
return h.handle(html_content)
except Exception:
# Keep extraction resilient if html2text is unavailable or fails.
except ImportError:
# Fallback: return raw HTML if html2text is not available
return html_content
# Handle content stored as attachment
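The hunk above switches the handler between a broad `except Exception` and a narrower `except ImportError`. A sketch of the narrower variant, assuming `html2text` is an optional dependency; `html_to_text` is a hypothetical function name:

```python
def html_to_text(html_content: str) -> str:
    """Convert HTML to text, returning the raw HTML if html2text is missing."""
    try:
        import html2text
    except ImportError:
        # Fallback: return raw HTML if html2text is not available.
        return html_content

    h = html2text.HTML2Text()
    h.ignore_links = False
    h.ignore_images = True
    return h.handle(html_content)


text = html_to_text("<p>hello</p>")
```

With the narrow handler, a failure inside `h.handle()` propagates to the caller instead of silently returning raw HTML. The broad handler trades that visibility for resilience.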

@@ -67,7 +67,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -144,11 +143,10 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -170,7 +168,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -200,7 +197,6 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

@@ -211,7 +211,7 @@ class AgentOutputBlock(Block):
if input_data.format:
try:
formatter = TextFormatter(autoescape=input_data.escape_html)
yield "output", await formatter.format_string(
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
except Exception as e:

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url_host(input_data.url)
parsed_url, _, _ = await validate_url(input_data.url, [])
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

@@ -31,21 +31,10 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.openai_responses import (
convert_tools_to_responses_format,
extract_responses_content,
extract_responses_reasoning,
extract_responses_tool_calls,
extract_responses_usage,
)
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -118,6 +107,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
@@ -126,7 +116,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_4_6_OPUS = "claude-opus-4-6"
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -146,31 +135,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
CODESTRAL = "mistralai/codestral-2508"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -178,11 +155,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
MICROSOFT_PHI_4 = "microsoft/phi-4"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_3 = "x-ai/grok-3"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
@@ -283,6 +258,9 @@ MODEL_METADATA = {
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
@@ -296,9 +274,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
), # claude-opus-4-6
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-6
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
@@ -357,41 +332,17 @@ MODEL_METADATA = {
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 2.5 Pro",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Pro Preview",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3 Flash Preview",
"OpenRouter",
"Google",
1,
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -399,15 +350,6 @@ MODEL_METADATA = {
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Flash Lite Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
@@ -429,78 +371,12 @@ MODEL_METADATA = {
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
"open_router",
262144,
None,
"Mistral Large 3 2512",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
"open_router",
131072,
None,
"Mistral Medium 3.1",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
"open_router",
131072,
131072,
"Mistral Small 3.2 24B",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.CODESTRAL: ModelMetadata(
"open_router",
256000,
None,
"Codestral 2508",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Translate 08.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
"open_router",
256000,
32768,
"Command A Reasoning 08.2025",
"OpenRouter",
"Cohere",
3,
),
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Vision 07.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
@@ -513,15 +389,6 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
"open_router",
128000,
8000,
"Sonar Reasoning Pro",
"OpenRouter",
"Perplexity",
2,
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
@@ -567,9 +434,6 @@ MODEL_METADATA = {
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
@@ -579,15 +443,6 @@ MODEL_METADATA = {
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_3: ModelMetadata(
"open_router",
131072,
131072,
"Grok 3",
"OpenRouter",
"xAI",
2,
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
@@ -804,53 +659,36 @@ async def llm_call(
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
tools_param = convert_tools_to_responses_format(tools) if tools else openai.omit
parallel_tool_calls = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
text_config = openai.omit
if force_json_output:
text_config = {"format": {"type": "json_object"}} # type: ignore
response_format = {"type": "json_object"}
response = await oai_client.responses.create(
response = await oai_client.chat.completions.create(
model=llm_model.value,
input=prompt, # type: ignore[arg-type]
tools=tools_param, # type: ignore[arg-type]
max_output_tokens=max_tokens,
parallel_tool_calls=get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
),
text=text_config, # type: ignore[arg-type]
store=False,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls,
)
raw_tool_calls = extract_responses_tool_calls(response)
tool_calls = (
[
ToolContentBlock(
id=tc["id"],
type=tc["type"],
function=ToolCall(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
),
)
for tc in raw_tool_calls
]
if raw_tool_calls
else None
)
reasoning = extract_responses_reasoning(response)
content = extract_responses_content(response)
prompt_tokens, completion_tokens = extract_responses_usage(response)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
return LLMResponse(
raw_response=response,
raw_response=response.choices[0].message,
prompt=prompt,
response=content,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
elif provider == "anthropic":
@@ -962,11 +800,6 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -988,7 +821,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
@@ -1296,10 +1129,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
values = input_data.prompt_values
if values:
input_data.prompt = await fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = await fmt.format_string(
input_data.sys_prompt, values
)
input_data.prompt = fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
if input_data.sys_prompt:
prompt.append({"role": "system", "content": input_data.sys_prompt})
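The context line `max_tokens = max(min(available_tokens, model_max_output, user_max), 1)` in `llm_call` above clamps the output budget to the tightest limit while never going below 1. A sketch of that clamp as a standalone function. Treating `None` as "no limit" is an assumption made here, since several `ModelMetadata` entries list `None` for max output and the diff does not show how those are resolved before this line:

```python
from typing import Optional


def clamp_max_tokens(
    available_tokens: int,
    model_max_output: Optional[int],
    user_max: Optional[int],
) -> int:
    """Pick the smallest applicable output limit, with a floor of 1."""
    limits = [available_tokens]
    # Assumption: a missing limit simply does not constrain the result.
    limits += [v for v in (model_max_output, user_max) if v is not None]
    return max(min(limits), 1)


budget = clamp_max_tokens(100_000, 4096, 2000)
```

The floor of 1 matters when the prompt already fills the context window and `available_tokens` is zero or negative.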

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr, field_validator
from pydantic import SecretStr
from backend.blocks._base import (
Block,
@@ -13,7 +13,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -22,7 +21,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -36,20 +34,6 @@ class PerplexityModel(str, Enum):
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
if isinstance(value, PerplexityModel):
return value
try:
return PerplexityModel(value)
except ValueError:
logger.warning(
f"Invalid PerplexityModel '{value}', "
f"falling back to {PerplexityModel.SONAR.value}"
)
return PerplexityModel.SONAR
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
@@ -88,25 +72,6 @@ class PerplexityBlock(Block):
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
@field_validator("model", mode="before")
@classmethod
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
"""Fall back to SONAR if the model value is not a valid
PerplexityModel (e.g. an OpenAI model ID set by the agent
generator)."""
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
system_prompt: str = SchemaField(
title="System Prompt",
default="",
@@ -171,7 +136,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
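The `_sanitize_perplexity_model` function in the hunks above falls back to a default model when it is handed a value that is not a member of the enum. A self-contained sketch of that fallback; the enum here is a stand-in whose `SONAR` and `SONAR_PRO` values are assumed from the model IDs listed elsewhere in this diff:

```python
import logging
from enum import Enum
from typing import Any

logger = logging.getLogger(__name__)


class PerplexityModel(str, Enum):
    # Stand-in enum for illustration only.
    SONAR = "perplexity/sonar"
    SONAR_PRO = "perplexity/sonar-pro"
    SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


def sanitize_perplexity_model(value: Any) -> PerplexityModel:
    """Return a valid PerplexityModel, falling back to SONAR otherwise."""
    if isinstance(value, PerplexityModel):
        return value
    try:
        return PerplexityModel(value)
    except ValueError:
        logger.warning(
            "Invalid PerplexityModel %r, falling back to %s",
            value,
            PerplexityModel.SONAR.value,
        )
        return PerplexityModel.SONAR


fallback = sanitize_perplexity_model("gpt-4o")
kept = sanitize_perplexity_model("perplexity/sonar-pro")
```

Silently substituting a default keeps a graph running when an upstream generator writes a model ID from another provider, at the cost of hiding the misconfiguration behind a log line.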

@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
("post_id", "abc123"),
],
test_mock={"delete_post": lambda creds, post_id: True},
is_sensitive_action=True,
)
@staticmethod
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
("comment_id", "xyz789"),
],
test_mock={"delete_comment": lambda creds, comment_id: True},
is_sensitive_action=True,
)
@staticmethod

@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
"_convert_to_color": lambda *args, **kwargs: "black",
},
is_sensitive_action=True,
)
async def run(

@@ -61,27 +61,20 @@ class ExecutionParams(BaseModel):
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids = []
# OpenAI Responses API: function_call items have type="function_call"
if entry.get("type") == "function_call":
if call_id := entry.get("call_id"):
tool_call_ids.append(call_id)
return tool_call_ids
if entry.get("role") != "assistant":
return tool_call_ids
# OpenAI Chat Completions: check for tool_calls in the entry.
# OpenAI: check for tool_calls in the entry.
calls = entry.get("tool_calls")
if isinstance(calls, list):
for call in calls:
if tool_id := call.get("id"):
tool_call_ids.append(tool_id)
# Anthropic: check content items for tool_use type.
# Anthropics: check content items for tool_use type.
content = entry.get("content")
if isinstance(content, list):
for item in content:
@@ -96,22 +89,16 @@ def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool response.
Supports OpenAI Chat Completions, Responses API, and Anthropic formats.
Supports both OpenAI and Anthropics formats.
"""
tool_call_ids: list[str] = []
# OpenAI Responses API: function_call_output items
if entry.get("type") == "function_call_output":
if call_id := entry.get("call_id"):
tool_call_ids.append(str(call_id))
return tool_call_ids
# OpenAI Chat Completions: a tool response message with role "tool".
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
if entry.get("role") == "tool":
if tool_call_id := entry.get("tool_call_id"):
tool_call_ids.append(str(tool_call_id))
# Anthropic: check content items for tool_result type.
# Anthropics: check content items for tool_result type.
if entry.get("role") == "user":
content = entry.get("content")
if isinstance(content, list):
@@ -124,16 +111,14 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
return tool_call_ids
def _create_tool_response(
call_id: str, output: Any, *, responses_api: bool = False
) -> dict[str, Any]:
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
"""
Create a tool response message for OpenAI, Anthropic, or OpenAI Responses API,
based on the tool_id format and the responses_api flag.
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
"""
content = output if isinstance(output, str) else json.dumps(output)
# Anthropic format: tool IDs typically start with "toolu_"
# Anthropics format: tool IDs typically start with "toolu_"
if call_id.startswith("toolu_"):
return {
"role": "user",
@@ -143,11 +128,8 @@ def _create_tool_response(
],
}
# OpenAI Responses API format
if responses_api:
return {"type": "function_call_output", "call_id": call_id, "output": content}
# OpenAI Chat Completions format (default fallback)
# OpenAI format: tool IDs typically start with "call_".
# Or default fallback (if the tool_id doesn't match any known prefix)
return {"role": "tool", "tool_call_id": call_id, "content": content}
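The three message shapes that `_create_tool_response` chooses between in the hunk above can be sketched as one function. The body of the Anthropic branch is elided in the diff, so the `tool_result` item below is an assumption based on the usual Anthropic Messages format, and the public name `create_tool_response` is illustrative:

```python
import json
from typing import Any, Dict


def create_tool_response(
    call_id: str, output: Any, *, responses_api: bool = False
) -> Dict[str, Any]:
    """Build a tool response message in the format the provider expects."""
    content = output if isinstance(output, str) else json.dumps(output)

    # Anthropic-style IDs typically start with "toolu_".
    if call_id.startswith("toolu_"):
        return {
            "role": "user",
            # Assumed shape of the elided content item.
            "content": [
                {"type": "tool_result", "tool_use_id": call_id, "content": content}
            ],
        }

    # OpenAI Responses API: output items are typed, with no role.
    if responses_api:
        return {"type": "function_call_output", "call_id": call_id, "output": content}

    # OpenAI Chat Completions format, also the default fallback.
    return {"role": "tool", "tool_call_id": call_id, "content": content}


chat_msg = create_tool_response("call_1", {"ok": True})
responses_msg = create_tool_response("call_1", "done", responses_api=True)
anthropic_msg = create_tool_response("toolu_1", "done")
```

Dispatching on the ID prefix is a heuristic, so the `responses_api` flag is what distinguishes the two OpenAI formats, whose call IDs look alike.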
@@ -195,19 +177,10 @@ def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str
return tool_outputs
def _convert_raw_response_to_dict(
raw_response: Any,
) -> dict[str, Any] | list[dict[str, Any]]:
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
"""
Safely convert raw_response to dictionary format for conversation history.
Handles different response types from different LLM providers.
For the OpenAI Responses API, the raw_response is the entire Response
object. Its ``output`` items (messages, function_calls) are extracted
individually so they can be used as valid input items on the next call.
Returns a **list** of dicts in that case.
For Chat Completions / Anthropic / Ollama, returns a single dict.
"""
if isinstance(raw_response, str):
# Ollama returns a string, convert to dict format
@@ -215,28 +188,11 @@ def _convert_raw_response_to_dict(
elif isinstance(raw_response, dict):
# Already a dict (from tests or some providers)
return raw_response
elif _is_responses_api_object(raw_response):
# OpenAI Responses API: extract individual output items
items = [json.to_dict(item) for item in raw_response.output]
return items if items else [{"role": "assistant", "content": ""}]
else:
# Chat Completions / Anthropic return message objects
# OpenAI/Anthropic return objects, convert with json.to_dict
return json.to_dict(raw_response)
def _is_responses_api_object(obj: Any) -> bool:
"""Detect an OpenAI Responses API Response object.
These have ``object == "response"`` and an ``output`` list, but no
``role`` attribute (unlike ChatCompletionMessage).
"""
return (
getattr(obj, "object", None) == "response"
and hasattr(obj, "output")
and not hasattr(obj, "role")
)
def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str, int]:
"""
Every tool call entry in the conversation history requires a response.
@@ -798,34 +754,19 @@ class SmartDecisionMakerBlock(Block):
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
converted = _convert_raw_response_to_dict(response.raw_response)
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
if isinstance(converted, list):
# Responses API: output items are already individual dicts
has_tool_calls = any(
item.get("type") == "function_call" for item in converted
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
if response.reasoning and not has_tool_calls:
prompt.append(
{
"role": "assistant",
"content": f"[Reasoning]: {response.reasoning}",
}
)
prompt.extend(converted)
else:
# Chat Completions / Anthropic: single assistant message dict
has_tool_calls = isinstance(converted.get("content"), list) and any(
item.get("type") == "tool_use" for item in converted.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{
"role": "assistant",
"content": f"[Reasoning]: {response.reasoning}",
}
)
prompt.append(converted)
prompt.append(assistant_message)
if tool_outputs:
prompt.extend(tool_outputs)
@@ -835,8 +776,6 @@ class SmartDecisionMakerBlock(Block):
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
*,
responses_api: bool = False,
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
@@ -929,17 +868,13 @@ class SmartDecisionMakerBlock(Block):
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(
tool_call.id, tool_response_content, responses_api=responses_api
)
return _create_tool_response(tool_call.id, tool_response_content)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id,
f"Tool execution failed: {str(e)}",
responses_api=responses_api,
tool_call.id, f"Tool execution failed: {str(e)}"
)
async def _execute_tools_agent_mode(
@@ -960,7 +895,6 @@ class SmartDecisionMakerBlock(Block):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
use_responses_api = input_data.model.metadata.provider == "openai"
# Execution parameters for tool execution
execution_params = ExecutionParams(
@@ -1017,19 +951,14 @@ class SmartDecisionMakerBlock(Block):
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info,
execution_params,
execution_processor,
responses_api=use_responses_api,
tool_info, execution_params, execution_processor
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id,
f"Error: {str(e)}",
responses_api=use_responses_api,
tool_info.tool_call.id, f"Error: {str(e)}"
)
tool_outputs.append(error_response)
@@ -1091,17 +1020,11 @@ class SmartDecisionMakerBlock(Block):
if pending_tool_calls and input_data.last_tool_output is None:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
use_responses_api = input_data.model.metadata.provider == "openai"
tool_output = []
if pending_tool_calls and input_data.last_tool_output is not None:
first_call_id = next(iter(pending_tool_calls.keys()))
tool_output.append(
_create_tool_response(
first_call_id,
input_data.last_tool_output,
responses_api=use_responses_api,
)
_create_tool_response(first_call_id, input_data.last_tool_output)
)
prompt.extend(tool_output)
@@ -1127,15 +1050,11 @@ class SmartDecisionMakerBlock(Block):
values = input_data.prompt_values
if values:
input_data.prompt = await llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = await llm.fmt.format_string(
input_data.sys_prompt, values
)
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
if input_data.sys_prompt and not any(
p.get("role") == "system"
and isinstance(p.get("content"), str)
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
):
prompt.append(
@@ -1146,9 +1065,7 @@ class SmartDecisionMakerBlock(Block):
)
if input_data.prompt and not any(
p.get("role") == "user"
and isinstance(p.get("content"), str)
and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
):
prompt.append(
@@ -1256,26 +1173,11 @@ class SmartDecisionMakerBlock(Block):
)
yield emit_key, arg_value
converted = _convert_raw_response_to_dict(response.raw_response)
# Check for tool calls to avoid inserting reasoning between tool pairs
if isinstance(converted, list):
has_tool_calls = any(
item.get("type") == "function_call" for item in converted
)
else:
has_tool_calls = isinstance(converted.get("content"), list) and any(
item.get("type") == "tool_use" for item in converted.get("content", [])
)
if response.reasoning and not has_tool_calls:
if response.reasoning:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
if isinstance(converted, list):
prompt.extend(converted)
else:
prompt.append(converted)
prompt.append(_convert_raw_response_to_dict(response.raw_response))
yield "conversations", prompt
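For reference, the tool-response shapes this diff switches between can be sketched as one standalone function. This is a minimal illustration under stated assumptions, not the repository's code: the name `create_tool_response_sketch` is hypothetical, and because the hunk above elides the Anthropic `tool_result` body, the shape used here is assumed from Anthropic's public Messages API format.

```python
import json
from typing import Any


def create_tool_response_sketch(
    call_id: str, output: Any, *, responses_api: bool = False
) -> dict[str, Any]:
    """Hypothetical stand-in for _create_tool_response (shapes partly assumed)."""
    # Non-string outputs are serialized so every provider receives text.
    content = output if isinstance(output, str) else json.dumps(output)

    # Anthropic tool IDs typically start with "toolu_"; this check runs first,
    # mirroring the order of the branches in the diff.
    if call_id.startswith("toolu_"):
        return {
            "role": "user",
            "content": [
                # Assumed shape: the diff does not show this list's contents.
                {"type": "tool_result", "tool_use_id": call_id, "content": content}
            ],
        }

    # OpenAI Responses API: a function_call_output input item.
    if responses_api:
        return {"type": "function_call_output", "call_id": call_id, "output": content}

    # OpenAI Chat Completions format, also used as the default fallback.
    return {"role": "tool", "tool_call_id": call_id, "content": content}
```

The side of the diff without the `responses_api` flag corresponds to always taking the last branch for non-Anthropic IDs.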

View File

@@ -83,8 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
@property
def provider_name(self) -> str:
@@ -138,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -228,7 +227,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -325,7 +324,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()

View File

@@ -1,8 +1,8 @@
import logging
from typing import Literal
from pydantic import BaseModel
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks._base import (
Block,
BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
category: str | None = SchemaField(
description="Filter by category", default=None
)
sort_by: StoreAgentsSortOptions = SchemaField(
description="How to sort the results", default=StoreAgentsSortOptions.RATING
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
description="How to sort the results", default="rating"
)
limit: int = SchemaField(
description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
self,
query: str | None = None,
category: str | None = None,
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
limit: int = 10,
) -> SearchAgentsResponse:
"""

View File

@@ -1,223 +0,0 @@
"""Tests for AutoPilotBlock: recursion guard, streaming, validation, and error paths."""
import asyncio
from unittest.mock import AsyncMock
import pytest
from backend.blocks.autopilot import (
AUTOPILOT_BLOCK_ID,
AutoPilotBlock,
_autopilot_recursion_depth,
_autopilot_recursion_limit,
_check_recursion,
_reset_recursion,
)
from backend.data.execution import ExecutionContext
def _make_context(user_id: str = "test-user-123") -> ExecutionContext:
"""Helper to build an ExecutionContext for tests."""
return ExecutionContext(
user_id=user_id,
graph_id="graph-1",
graph_exec_id="gexec-1",
graph_version=1,
node_id="node-1",
node_exec_id="nexec-1",
)
# ---------------------------------------------------------------------------
# Recursion guard unit tests
# ---------------------------------------------------------------------------
class TestCheckRecursion:
"""Unit tests for _check_recursion / _reset_recursion."""
def test_first_call_increments_depth(self):
tokens = _check_recursion(3)
try:
assert _autopilot_recursion_depth.get() == 1
assert _autopilot_recursion_limit.get() == 3
finally:
_reset_recursion(tokens)
def test_reset_restores_previous_values(self):
assert _autopilot_recursion_depth.get() == 0
assert _autopilot_recursion_limit.get() is None
tokens = _check_recursion(5)
_reset_recursion(tokens)
assert _autopilot_recursion_depth.get() == 0
assert _autopilot_recursion_limit.get() is None
def test_exceeding_limit_raises(self):
t1 = _check_recursion(2)
try:
t2 = _check_recursion(2)
try:
with pytest.raises(RuntimeError, match="recursion depth limit"):
_check_recursion(2)
finally:
_reset_recursion(t2)
finally:
_reset_recursion(t1)
def test_nested_calls_respect_inherited_limit(self):
"""Inner call with higher max_depth still respects outer limit."""
t1 = _check_recursion(2) # sets limit=2
try:
t2 = _check_recursion(10) # inner wants 10, but inherited is 2
try:
# depth is now 2, limit is min(10, 2) = 2 → should raise
with pytest.raises(RuntimeError, match="recursion depth limit"):
_check_recursion(10)
finally:
_reset_recursion(t2)
finally:
_reset_recursion(t1)
def test_limit_of_one_blocks_immediately_on_second_call(self):
t1 = _check_recursion(1)
try:
with pytest.raises(RuntimeError):
_check_recursion(1)
finally:
_reset_recursion(t1)
# ---------------------------------------------------------------------------
# AutoPilotBlock.run() validation tests
# ---------------------------------------------------------------------------
class TestRunValidation:
"""Tests for input validation in AutoPilotBlock.run()."""
@pytest.fixture
def block(self):
return AutoPilotBlock()
@pytest.mark.asyncio
async def test_empty_prompt_yields_error(self, block):
block.Input # ensure schema is accessible
input_data = block.Input(prompt=" ", max_recursion_depth=3)
ctx = _make_context()
outputs = {}
async for name, value in block.run(input_data, execution_context=ctx):
outputs[name] = value
assert outputs.get("error") == "Prompt cannot be empty."
assert "response" not in outputs
@pytest.mark.asyncio
async def test_missing_user_id_yields_error(self, block):
input_data = block.Input(prompt="hello", max_recursion_depth=3)
ctx = _make_context(user_id="")
outputs = {}
async for name, value in block.run(input_data, execution_context=ctx):
outputs[name] = value
assert "authenticated user" in outputs.get("error", "")
@pytest.mark.asyncio
async def test_successful_run_yields_all_outputs(self, block):
"""With execute_copilot mocked, run() should yield all 5 success outputs."""
mock_result = (
"Hello world",
[],
'[{"role":"user","content":"hi"}]',
"sess-abc",
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)
block.execute_copilot = AsyncMock(return_value=mock_result)
block.create_session = AsyncMock(return_value="sess-abc")
input_data = block.Input(prompt="hi", max_recursion_depth=3)
ctx = _make_context()
outputs = {}
async for name, value in block.run(input_data, execution_context=ctx):
outputs[name] = value
assert outputs["response"] == "Hello world"
assert outputs["tool_calls"] == []
assert outputs["session_id"] == "sess-abc"
assert outputs["token_usage"]["total_tokens"] == 15
assert "error" not in outputs
@pytest.mark.asyncio
async def test_exception_yields_error(self, block):
"""On unexpected failure, run() should yield an error output."""
block.execute_copilot = AsyncMock(side_effect=RuntimeError("boom"))
block.create_session = AsyncMock(return_value="sess-fail")
input_data = block.Input(prompt="do something", max_recursion_depth=3)
ctx = _make_context()
outputs = {}
async for name, value in block.run(input_data, execution_context=ctx):
outputs[name] = value
assert outputs["session_id"] == "sess-fail"
assert "boom" in outputs.get("error", "")
@pytest.mark.asyncio
async def test_cancelled_error_yields_error_and_reraises(self, block):
"""CancelledError should yield error, then re-raise."""
block.execute_copilot = AsyncMock(side_effect=asyncio.CancelledError())
block.create_session = AsyncMock(return_value="sess-cancel")
input_data = block.Input(prompt="do something", max_recursion_depth=3)
ctx = _make_context()
outputs = {}
with pytest.raises(asyncio.CancelledError):
async for name, value in block.run(input_data, execution_context=ctx):
outputs[name] = value
assert outputs["session_id"] == "sess-cancel"
assert "cancelled" in outputs.get("error", "").lower()
@pytest.mark.asyncio
async def test_existing_session_id_skips_create(self, block):
"""When session_id is provided, create_session should not be called."""
mock_result = (
"ok",
[],
"[]",
"existing-sid",
{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
)
block.execute_copilot = AsyncMock(return_value=mock_result)
block.create_session = AsyncMock()
input_data = block.Input(
prompt="test", session_id="existing-sid", max_recursion_depth=3
)
ctx = _make_context()
async for _ in block.run(input_data, execution_context=ctx):
pass
block.create_session.assert_not_called()
# ---------------------------------------------------------------------------
# Block registration / ID tests
# ---------------------------------------------------------------------------
class TestBlockRegistration:
def test_block_id_matches_constant(self):
block = AutoPilotBlock()
assert block.id == AUTOPILOT_BLOCK_ID
def test_max_recursion_depth_has_upper_bound(self):
"""Schema should enforce le=10."""
schema = AutoPilotBlock.Input.model_json_schema()
max_rec = schema["properties"]["max_recursion_depth"]
assert (
max_rec.get("maximum") == 10 or max_rec.get("exclusiveMaximum", 999) <= 11
)
def test_output_schema_has_no_duplicate_error_field(self):
"""Output should inherit error from BlockSchemaOutput, not redefine it."""
# The field should exist (inherited) but there should be no explicit
# redefinition. We verify by checking the class __annotations__ directly.
assert "error" not in AutoPilotBlock.Output.__annotations__

View File

@@ -13,17 +13,18 @@ class TestLLMStatsTracking:
"""Test that llm_call returns proper token counts in LLMResponse."""
import backend.blocks.llm as llm
# Mock the OpenAI Responses API response
# Mock the OpenAI client
mock_response = MagicMock()
mock_response.output_text = "Test response"
mock_response.output = []
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
mock_response.choices = [
MagicMock(message=MagicMock(content="Test response", tool_calls=None))
]
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
# Test with mocked OpenAI response
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_client.responses.create = AsyncMock(return_value=mock_response)
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
@@ -270,17 +271,30 @@ class TestLLMStatsTracking:
mock_response = MagicMock()
# Return different responses for chunk summary vs final summary
if call_count == 1:
mock_response.output_text = '<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>'
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
tool_calls=None,
)
)
]
else:
mock_response.output_text = '<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>'
mock_response.output = []
mock_response.usage = MagicMock(input_tokens=50, output_tokens=30)
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
tool_calls=None,
)
)
]
mock_response.usage = MagicMock(prompt_tokens=50, completion_tokens=30)
return mock_response
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_client.responses.create = mock_create
mock_client.chat.completions.create = mock_create
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(

View File

@@ -1,81 +0,0 @@
"""Unit tests for PerplexityBlock model fallback behavior."""
import pytest
from backend.blocks.perplexity import (
TEST_CREDENTIALS_INPUT,
PerplexityBlock,
PerplexityModel,
)
def _make_input(**overrides) -> dict:
defaults = {
"prompt": "test query",
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return defaults
class TestPerplexityModelFallback:
"""Tests for fallback_invalid_model field_validator."""
def test_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
assert inp.model == PerplexityModel.SONAR
def test_another_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
assert inp.model == PerplexityModel.SONAR
def test_valid_model_string_is_kept(self):
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
assert inp.model == PerplexityModel.SONAR_PRO
def test_valid_enum_value_is_kept(self):
inp = PerplexityBlock.Input(
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
)
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
def test_default_model_when_omitted(self):
inp = PerplexityBlock.Input(**_make_input())
assert inp.model == PerplexityModel.SONAR
@pytest.mark.parametrize(
"model_value",
[
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
)
def test_all_valid_models_accepted(self, model_value: str):
inp = PerplexityBlock.Input(**_make_input(model=model_value))
assert inp.model.value == model_value
class TestPerplexityValidateData:
"""Tests for validate_data which runs during block execution (before
Pydantic instantiation). Invalid models must be sanitized here so
JSON schema validation does not reject them."""
def test_invalid_model_sanitized_before_schema_validation(self):
data = _make_input(model="gpt-5.2-2025-12-11")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == PerplexityModel.SONAR.value
def test_valid_model_unchanged_by_validate_data(self):
data = _make_input(model="perplexity/sonar-pro")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == "perplexity/sonar-pro"
def test_missing_model_uses_default(self):
data = _make_input() # no model key
error = PerplexityBlock.Input.validate_data(data)
assert error is None
inp = PerplexityBlock.Input(**data)
assert inp.model == PerplexityModel.SONAR
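The fallback rule exercised by these tests (unknown model strings collapse to the default, valid strings and enum members pass through) can be approximated without Pydantic. This is a stdlib-only sketch: `ModelSketch` and `coerce_model` are hypothetical names, and the enum values are inferred from the parametrized test above rather than taken from `backend.blocks.perplexity`.

```python
from enum import Enum
from typing import Union


class ModelSketch(str, Enum):
    # Values inferred from the test parametrization; names are assumptions.
    SONAR = "perplexity/sonar"
    SONAR_PRO = "perplexity/sonar-pro"
    SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


def coerce_model(value: Union[str, ModelSketch, None]) -> ModelSketch:
    """Return a valid model, falling back to SONAR for missing or unknown input."""
    if value is None:
        return ModelSketch.SONAR
    try:
        # Enum lookup by value accepts both raw strings and existing members.
        return ModelSketch(value)
    except ValueError:
        return ModelSketch.SONAR
```

In the block under test, the same rule is applied both in a field validator and in `validate_data`, so that JSON schema validation never sees the invalid value.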

View File

@@ -2,7 +2,6 @@ from unittest.mock import MagicMock
import pytest
from backend.api.features.store.db import StoreAgentsSortOptions
from backend.blocks.system.library_operations import (
AddToLibraryFromStoreBlock,
LibraryAgent,
@@ -122,10 +121,7 @@ async def test_search_store_agents_block(mocker):
)
input_data = block.Input(
query="test",
category="productivity",
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
limit=10,
query="test", category="productivity", sort_by="rating", limit=10
)
outputs = {}

View File

@@ -290,9 +290,7 @@ class FillTextTemplateBlock(Block):
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
formatter = text.TextFormatter(autoescape=input_data.escape_html)
yield "output", await formatter.format_string(
input_data.format, input_data.values
)
yield "output", formatter.format_string(input_data.format, input_data.values)
class CombineTextsBlock(Block):

View File

@@ -22,7 +22,6 @@ from backend.copilot.model import (
update_session_title,
upsert_chat_session,
)
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
@@ -36,15 +35,13 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
_get_openai_client,
client,
config,
)
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
@@ -65,8 +62,8 @@ async def _update_title_async(
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title and user_id:
await update_session_title(session_id, user_id, title, only_if_empty=True)
if title:
await update_session_title(session_id, title)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
@@ -91,7 +88,7 @@ async def _compress_session_messages(
result = await compress_context(
messages=messages_dict,
model=config.model,
client=_get_openai_client(),
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
@@ -179,17 +176,14 @@ async def stream_chat_completion_baseline(
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
base_system_prompt, _ = await _build_system_prompt(
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
base_system_prompt, _ = await _build_system_prompt(
system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
@@ -223,10 +217,6 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -238,31 +228,16 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -415,7 +390,6 @@ async def stream_chat_completion_baseline(
)
except Exception as e:
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
@@ -433,49 +407,6 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
# turns consume prompt tokens on each API call, so the total is underestimated.
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
# cannot break out cache_read/cache_creation weights. Users on the
# baseline path may be slightly over-counted vs the SDK path.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -486,16 +417,4 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()
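The token accounting that differs between the two sides of this file can be summarised as a pure function: sum whatever usage the streaming chunks report, and estimate only when nothing was reported and the turn did not fail without output. The sketch below is an illustration under assumptions. `UsageSketch`, `rough_token_estimate`, and `total_turn_usage` are hypothetical names, and the four-characters-per-token heuristic stands in for the tiktoken-based helpers the diff imports.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass
class UsageSketch:
    """Hypothetical stand-in for the usage object attached to a stream chunk."""
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


def rough_token_estimate(text: str) -> int:
    # Assumption: about 4 characters per token; the real code uses tiktoken.
    return max(len(text) // 4, 1) if text else 0


def total_turn_usage(
    chunk_usages: Iterable[Optional[UsageSketch]],
    prompt_text: str,
    assistant_text: str,
    stream_error: bool = False,
) -> Tuple[int, int]:
    """Return (prompt_tokens, completion_tokens) for one turn."""
    prompt = completion = 0
    # Sum across all tool-call rounds; chunks without usage contribute nothing.
    for usage in chunk_usages:
        if usage:
            prompt += usage.prompt_tokens or 0
            completion += usage.completion_tokens or 0

    # Do not charge estimated tokens for a turn that failed with no output.
    failed_without_output = stream_error and not assistant_text
    if prompt == 0 and completion == 0 and not failed_without_output:
        prompt = max(rough_token_estimate(prompt_text), 1)
        completion = rough_token_estimate(assistant_text)
    return prompt, completion
```

As the comments in the diff note, the estimate covers a single round, so multi-round tool-calling turns are undercounted when the provider reports no usage.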

View File

@@ -1,13 +1,10 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -22,7 +19,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
default="https://openrouter.ai/api/v1",
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -70,27 +67,6 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Per-turn token cost varies with context size: ~10-15K for early turns,
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
@@ -115,22 +91,10 @@ class ChatConfig(BaseSettings):
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
use_openrouter: bool = Field(
default=True,
description="Enable routing API calls through the OpenRouter proxy. "
"The actual decision also requires ``api_key`` and ``base_url`` — "
"use the ``openrouter_active`` property for the final answer.",
)
use_claude_code_subscription: bool = Field(
default=False,
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
)
test_mode: bool = Field(
default=False,
description="Use dummy service instead of real LLM calls. "
"Send __test_transient_error__, __test_fatal_error__, or "
"__test_slow_response__ to trigger specific scenarios.",
)
# E2B Sandbox Configuration
use_e2b_sandbox: bool = Field(
@@ -148,52 +112,18 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=420, # 7 min safety net — allows headroom for compaction retries
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
)
@property
def openrouter_active(self) -> bool:
"""True when OpenRouter is enabled AND credentials are usable.
Single source of truth for "will the SDK route through OpenRouter?".
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
present — mirrors the fallback logic in ``_build_sdk_env``.
"""
if not self.use_openrouter:
return False
base = (self.base_url or "").rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
return bool(self.api_key and base and base.startswith("http"))
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
def get_use_e2b_sandbox(cls, v):
"""Get use_e2b_sandbox from environment if not provided."""
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return True if v is None else v
@field_validator("e2b_api_key", mode="before")
@classmethod
@@ -234,9 +164,29 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = OPENROUTER_BASE_URL
v = "https://openrouter.ai/api/v1"
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
@field_validator("use_claude_code_subscription", mode="before")
@classmethod
def get_use_claude_code_subscription(cls, v):
"""Get use_claude_code_subscription from environment if not provided."""
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return False if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
@@ -246,7 +196,6 @@ class ChatConfig(BaseSettings):
class Config:
"""Pydantic config."""
env_prefix = "CHAT_"
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore" # Ignore extra environment variables
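
The `use_*` validators in this hunk share one pattern: a non-empty `CHAT_*` environment variable takes precedence over the value passed in, and only `true`, `1`, `yes`, or `on` (case-insensitive) count as enabled. Below is a minimal standalone sketch of that pattern. The helper name `env_flag` and the variable `CHAT_DEMO_FLAG` are hypothetical and do not appear in the diff; the logic is copied from the validators as shown.

```python
from __future__ import annotations

import os


def env_flag(name: str, value: bool | None, default: bool) -> bool:
    """Resolve a boolean setting the way the validators in this hunk do.

    `env_flag` is an illustrative helper, not part of ChatConfig.
    """
    env_val = os.getenv(name, "").lower()
    if env_val:
        # Any non-empty value outside the truthy set disables the flag,
        # even if the caller passed value=True.
        return env_val in ("true", "1", "yes", "on")
    # No environment override: use the passed value, or the default if None.
    return default if value is None else value
```

One consequence worth noting: a misspelled value such as `CHAT_USE_E2B_SANDBOX=ture` would be treated as disabled rather than rejected, which is presumably why the tests clear these variables before constructing a config.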


@@ -1,89 +0,0 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_ENV_VARS_TO_CLEAR = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
"CHAT_USE_OPENROUTER",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
"CHAT_BASE_URL",
"OPENROUTER_BASE_URL",
"OPENAI_BASE_URL",
)
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _ENV_VARS_TO_CLEAR:
monkeypatch.delenv(var, raising=False)
class TestOpenrouterActive:
"""Tests for the openrouter_active property."""
def test_enabled_with_credentials_returns_true(self):
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
)
assert cfg.openrouter_active is True
def test_enabled_but_missing_api_key_returns_false(self):
cfg = ChatConfig(
use_openrouter=True,
api_key=None,
base_url="https://openrouter.ai/api/v1",
)
assert cfg.openrouter_active is False
def test_disabled_returns_false_despite_credentials(self):
cfg = ChatConfig(
use_openrouter=False,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
)
assert cfg.openrouter_active is False
def test_strips_v1_suffix_and_still_valid(self):
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
)
assert cfg.openrouter_active is True
def test_invalid_base_url_returns_false(self):
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="not-a-url",
)
assert cfg.openrouter_active is False
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False
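
The `TestOpenrouterActive` cases above exercise the `openrouter_active` property from the config hunk. As a sketch only, the same decision can be restated as a pure function, which makes the rule easier to check without constructing a `ChatConfig`. The function name `is_openrouter_active` is hypothetical; the body mirrors the property as it appears in the diff.

```python
from __future__ import annotations


def is_openrouter_active(
    use_openrouter: bool, api_key: str | None, base_url: str | None
) -> bool:
    """Standalone restatement of the openrouter_active property (illustrative)."""
    if not use_openrouter:
        return False
    # Normalise the URL: drop trailing slashes, then a trailing "/v1".
    base = (base_url or "").rstrip("/")
    if base.endswith("/v1"):
        base = base[:-3]
    # Active only with a key and an http(s)-looking base URL.
    return bool(api_key and base and base.startswith("http"))
```

Under this logic a base URL only has to start with `http`, so the check is a sanity filter rather than full URL validation.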

Some files were not shown because too many files have changed in this diff