Compare commits

..

1 Commits

Author SHA1 Message Date
Swifty
af7e8d19fe feat(platform): add beta invite provisioning 2026-03-09 16:38:23 +01:00
286 changed files with 6114 additions and 30266 deletions

View File

@@ -1,79 +0,0 @@
---
name: pr-address
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Address
## Find the PR
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
## Fetch comments (all sources)
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
```
**Bots to watch for:**
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — automated review. Address actionable items.
## For each unaddressed comment
Address comments **one at a time**: fix → commit → push → inline reply → next.
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
## Format and commit
After fixing, format the changed code:
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
If API routes changed, regenerate the frontend client:
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
trap "kill $REST_PID 2>/dev/null" EXIT
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID 2>/dev/null; trap - EXIT
```
Never manually edit files in `src/app/api/__generated__/`.
Then commit and **push immediately** — never batch commits without pushing.
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
## The loop
```text
address comments → format → commit → push
→ re-check comments → fix new ones → push
→ wait for CI → re-check comments after CI settles
→ repeat until: all comments addressed AND CI green AND no new comments arriving
```
While CI runs, stay productive: run local tests, address remaining comments.
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.

View File

@@ -1,74 +0,0 @@
---
name: pr-review
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review
## Find the PR
```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```
## Read the diff
```bash
gh pr diff {N}
```
## Fetch existing review comments
Before posting anything, fetch existing inline comments to avoid duplicates:
```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
```
## What to check
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
## Output format
Every comment **must** be prefixed with `🤖` and a criticality badge:
| Tier | Badge | Meaning |
|---|---|---|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
| Should Fix | `🟠 **Should Fix**` | Important improvement |
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
| Nit | `🔵 **Nit**` | Style / wording |
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
## Post inline comments
For each finding, post an inline comment on the PR (do not just write a local report):
```bash
# Get the latest commit SHA for the PR
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
# Post an inline comment on a specific file/line
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
-f body="🤖 🔴 **Blocker**: <description>" \
-f commit_id="$COMMIT_SHA" \
-f path="<file path>" \
-F line=<line number>
```

View File

@@ -1,85 +0,0 @@
---
name: worktree
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
metadata:
author: autogpt-team
version: "3.0.0"
---
# Worktree Setup
## Create the worktree
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
```bash
ROOT=$(git rev-parse --show-toplevel)
PARENT=$(dirname "$ROOT")
# From an existing branch
git worktree add "$PARENT/<NAME>" <branch-name>
# From a new branch off dev
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
```
## Copy environment files
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
```bash
ROOT=$(git rev-parse --show-toplevel)
TARGET="$(dirname "$ROOT")/<NAME>"
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
if [ -f "$ROOT/$envpath/.env" ]; then
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
elif [ -f "$ROOT/$envpath/.env.default" ]; then
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
fi
done
```
## Install dependencies
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
cd "$TARGET/autogpt_platform/frontend" && pnpm install
```
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
## Running the app (optional)
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd "$TARGET/autogpt_platform/backend" && poetry run app
```
## CoPilot testing
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
## Cleanup
```bash
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
```
## Alternative: Branchlet (optional)
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
```bash
branchlet create -n <name> -s <source-branch> -b <new-branch>
```

View File

@@ -5,14 +5,12 @@ on:
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
merge_group:

View File

@@ -120,6 +120,175 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
integration_test:
runs-on: ubuntu-latest
needs: setup

View File

@@ -1,18 +1,14 @@
name: AutoGPT Platform - Full-stack CI
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
merge_group:
@@ -28,28 +24,42 @@ defaults:
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Enable corepack
run: corepack enable
- name: Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install dependencies to populate cache
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
check-api-types:
name: check API types
runs-on: ubuntu-latest
types:
runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
@@ -57,256 +67,70 @@ jobs:
with:
submodules: recursive
# ------------------------ Backend setup ------------------------
- name: Set up Backend - Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Set up Backend - Install Poetry
working-directory: autogpt_platform/backend
run: |
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Installing Poetry version ${POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
- name: Set up Backend - Set up dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Set up Backend - Install dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Set up Backend - Generate Prisma client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
- name: Set up Frontend - Export OpenAPI schema from Backend
working-directory: autogpt_platform/backend
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
# ------------------------ Frontend setup ------------------------
- name: Set up Frontend - Enable corepack
run: corepack enable
- name: Set up Frontend - Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up Frontend - Install dependencies
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
- name: Restore dependencies cache
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up Frontend - Format OpenAPI schema
id: format-schema
run: pnpm prettier --write ./src/app/api/openapi.json
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after exporting the API schema."
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "\nIn the backend directory:"
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
echo "\nIn the frontend directory:"
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
echo "3. Run 'pnpm generate:api'"
echo "4. Run 'pnpm types'"
echo "5. Fix any TypeScript errors that may have been introduced"
echo "6. Commit and push your changes"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Set up Frontend - Generate API client
id: generate-api-client
run: pnpm orval --config ./orval.config.ts
# Continue with type generation & check even if there are schema changes
if: success() || (steps.format-schema.outcome == 'success')
- name: Check for TypeScript errors
- name: Run Typescript checks
run: pnpm types
if: success() || (steps.generate-api-client.outcome == 'success')
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs

View File

@@ -60,12 +60,9 @@ AutoGPT Platform is a monorepo containing:
### Reviewing/Revising Pull Requests
Use `/pr-review` to review a PR or `/pr-address` to address comments.
When fetching comments manually:
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
### Conventional Commits

View File

@@ -1,40 +0,0 @@
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Tracks authentication events (login, logout, SSO, password
-- reset, etc.) from Supabase's internal audit log.
-- Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
-- auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
-- created_at TIMESTAMPTZ When the auth event occurred
-- actor_id TEXT User ID who triggered the event
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
-- Rolling 90 days from current date
--
-- EXAMPLE QUERIES
-- -- Daily login counts
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
-- FROM analytics.auth_activities
-- WHERE action = 'login'
-- GROUP BY 1 ORDER BY 1;
--
-- -- SSO vs password login breakdown
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
-- WHERE action = 'login' GROUP BY 1;
-- =============================================================
SELECT
created_at,
payload->>'actor_id' AS actor_id,
payload->>'actor_via_sso' AS actor_via_sso,
payload->>'action' AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'

View File

@@ -1,105 +0,0 @@
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
-- One row per agent graph execution (last 90 days).
-- Unpacks the JSONB stats column into individual numeric columns
-- and normalises the executionStatus — runs that failed due to
-- insufficient credits are reclassified as 'NO_CREDITS' for
-- easier filtering. Error messages are scrubbed of IDs and URLs
-- to allow safe grouping.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
-- platform.AgentGraph — Agent graph metadata (for name)
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
-- id TEXT Execution UUID
-- agentGraphId TEXT Agent graph UUID
-- agentGraphVersion INT Graph version number
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
-- createdAt TIMESTAMPTZ When the execution was queued
-- updatedAt TIMESTAMPTZ Last status update time
-- userId TEXT Owner user UUID
-- agentGraphName TEXT Human-readable agent name
-- cputime DECIMAL Total CPU seconds consumed
-- walltime DECIMAL Total wall-clock seconds
-- node_count DECIMAL Number of nodes in the graph
-- nodes_cputime DECIMAL CPU time across all nodes
-- nodes_walltime DECIMAL Wall time across all nodes
-- execution_cost DECIMAL Credit cost of this execution
-- correctness_score FLOAT AI correctness score (if available)
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Daily execution counts by status
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
-- FROM analytics.graph_execution
-- GROUP BY 1, 2 ORDER BY 1;
--
-- -- Average cost per execution by agent
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'COMPLETED'
-- GROUP BY 1 ORDER BY avg_cost DESC;
--
-- -- Top error messages
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
-- FROM analytics.graph_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================
SELECT
ge."id" AS id,
ge."agentGraphId" AS agentGraphId,
ge."agentGraphVersion" AS agentGraphVersion,
CASE
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
AND (
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
)
THEN 'NO_CREDITS'
ELSE CAST(ge."executionStatus" AS TEXT)
END AS executionStatus,
ge."createdAt" AS createdAt,
ge."updatedAt" AS updatedAt,
ge."userId" AS userId,
g."name" AS agentGraphName,
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
ON ge."agentGraphId" = g."id"
AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
SELECT DISTINCT ON ("userId", "agentGraphId")
"userId", "agentGraphId",
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
FROM platform."LibraryAgent"
WHERE "isDeleted" = FALSE
AND "isArchived" = FALSE
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,101 +0,0 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================
SELECT
ne."id" AS id,
ne."agentGraphExecutionId" AS agentGraphExecutionId,
ne."agentNodeId" AS agentNodeId,
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
ne."addedTime" AS addedTime,
ne."queuedTime" AS queuedTime,
ne."startedTime" AS startedTime,
ne."endedTime" AS endedTime,
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
b."name" AS blockName,
b."id" AS blockId,
REGEXP_REPLACE(
REGEXP_REPLACE(
TRIM(BOTH '"' FROM eio."data"::text),
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
'\1\2/...', 'gi'
),
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
) AS groupedErrorMessage,
eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
ON eio."referencedByOutputExecId" = ne."id"
AND eio."name" = 'error'
AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,97 +0,0 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e
),
first_use AS (
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1,2
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
g.agent_id,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(ac.active_users,0) AS active_users,
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('day', e."createdAt")::date AS day_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fe.cohort_day_start,
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_exec fe USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
DATE_TRUNC('week', e."createdAt")::date AS week_start
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
SELECT user_id, MIN(created_at) AS first_exec_at,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fe.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,94 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
-- -- Average retention curve across all cohorts
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('day', s.created_at)::date AS day_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
FROM events GROUP BY 1
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
SELECT ad.user_id, fl.cohort_day_start,
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
FROM activity_days ad JOIN first_login fl USING (user_id)
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_day_start, cs.cohort_users,
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
g.cohort_day_start,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_day, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;

View File

@@ -1,96 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
onboarders AS (
SELECT fl.user_id FROM first_login_all fl
WHERE EXISTS (
SELECT 1 FROM platform."AgentGraphExecution" e
WHERE e."userId"::text = fl.user_id
AND e."createdAt" >= fl.first_login_time
AND e."createdAt" < fl.first_login_time
+ make_interval(days => (SELECT onboarding_window_days FROM params))
)
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;

View File

@@ -1,103 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
-- -- Overall average retention curve (all cohorts combined)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
WITH params AS (SELECT 12::int AS max_weeks),
events AS (
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
DATE_TRUNC('week', s.created_at)::date AS week_start
FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
SELECT user_id, MIN(created_at) AS first_login_time,
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
SELECT aw.user_id, fl.cohort_week_start,
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
FROM activity_weeks aw JOIN first_login fl USING (user_id)
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
SELECT cs.cohort_week_start, cs.cohort_users,
LEAST((SELECT max_weeks FROM params),
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
FROM cohort_sizes cs
),
grid AS (
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
g.cohort_week_start,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
g.user_lifetime_week, g.cohort_users,
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week

View File

@@ -1,71 +0,0 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
-- platform.CreditTransaction — Credit debit/credit records
-- platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
-- transactionKey TEXT Unique transaction identifier
-- userId TEXT User who was charged
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
-- transactionTime TIMESTAMPTZ When the transaction was recorded
-- blockId TEXT Block UUID that triggered the spend
-- blockName TEXT Human-readable block name
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
-- node_exec_id TEXT Linked node execution UUID
-- llm_call_count INT LLM API calls made in that execution
-- llm_retry_count INT LLM retries in that execution
-- llm_input_token_count INT Input tokens consumed
-- llm_output_token_count INT Output tokens produced
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
-- SUM("negativeAmount") AS total_cost,
-- SUM("llm_input_token_count") AS input_tokens,
-- SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================
SELECT
c."transactionKey" AS transactionKey,
c."userId" AS userId,
c."amount" AS amount,
c."amount" * -1 AS negativeAmount,
c."type" AS transactionType,
c."createdAt" AS transactionTime,
c.metadata->>'block_id' AS blockId,
c.metadata->>'block' AS blockName,
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
c.metadata->'input'->>'model' AS llm_model,
c.metadata->>'node_exec_id' AS node_exec_id,
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -1,45 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
-- id TEXT Onboarding record UUID
-- createdAt TIMESTAMPTZ When onboarding started
-- updatedAt TIMESTAMPTZ Last update to onboarding state
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
-- integrations TEXT[] Array of integration names the user selected
-- userId TEXT User UUID
-- completedSteps TEXT[] Array of onboarding step enums completed
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================
SELECT
id,
"createdAt",
"updatedAt",
"usageReason",
integrations,
"userId",
"completedSteps",
"selectedStoreListingVersionId"
FROM platform."UserOnboarding"

View File

@@ -1,100 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
-- platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
-- step_order INT Numeric position in the funnel (1=first, 22=last)
-- users_completed BIGINT Distinct users who completed this step
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
-- 7 CONGRATS 15 VISIT_COPILOT
-- 8 GET_RESULTS 16 RE_RUN_AGENT
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================
WITH all_steps AS (
-- Complete ordered grid of all 22 steps so zero-completion steps
-- are always present, keeping LAG comparisons correct.
SELECT step_name, step_order
FROM (VALUES
('WELCOME', 1),
('USAGE_REASON', 2),
('INTEGRATIONS', 3),
('AGENT_CHOICE', 4),
('AGENT_NEW_RUN', 5),
('AGENT_INPUT', 6),
('CONGRATS', 7),
('GET_RESULTS', 8),
('MARKETPLACE_VISIT', 9),
('MARKETPLACE_ADD_AGENT', 10),
('MARKETPLACE_RUN_AGENT', 11),
('BUILDER_OPEN', 12),
('BUILDER_SAVE_AGENT', 13),
('BUILDER_RUN_AGENT', 14),
('VISIT_COPILOT', 15),
('RE_RUN_AGENT', 16),
('SCHEDULE_AGENT', 17),
('RUN_AGENTS', 18),
('RUN_3_DAYS', 19),
('TRIGGER_WEBHOOK', 20),
('RUN_14_DAYS', 21),
('RUN_AGENTS_100', 22)
) AS t(step_name, step_order)
),
raw AS (
SELECT
u."userId",
step_txt::text AS step
FROM platform."UserOnboarding" u
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
SELECT step, COUNT(DISTINCT "userId") AS users_completed
FROM raw GROUP BY step
),
funnel AS (
SELECT
a.step_name AS step,
a.step_order,
COALESCE(sc.users_completed, 0) AS users_completed,
ROUND(
100.0 * COALESCE(sc.users_completed, 0)
/ NULLIF(
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
0
),
2
) AS pct_from_prev
FROM all_steps a
LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order

View File

@@ -1,41 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
-- platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
-- users_with_integration BIGINT Distinct users who selected this integration
--
-- WINDOW
-- Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================
WITH exploded AS (
SELECT
u."userId" AS user_id,
UNNEST(u."integrations") AS integration
FROM platform."UserOnboarding" u
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
integration,
COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC

View File

@@ -1,145 +0,0 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
-- auth.sessions — Login/session records
-- platform.AgentGraph — Graphs (agents) built by the user
-- platform.AgentGraphExecution — Agent run history
-- platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
-- user_id TEXT Supabase user UUID
-- first_login_time TIMESTAMPTZ First ever session created_at
-- last_login_time TIMESTAMPTZ Most recent session created_at
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
-- agent_runs BIGINT Total graph execution count (0 if none)
-- node_execution_count BIGINT Total node executions across all runs
-- node_execution_failed BIGINT Node executions with FAILED status
-- node_execution_completed BIGINT Node executions with COMPLETED status
-- node_execution_terminated BIGINT Node executions with TERMINATED status
-- node_execution_queued BIGINT Node executions with QUEUED status
-- node_execution_running BIGINT Node executions with RUNNING status
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
-- node_execution_review BIGINT Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
-- AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================
WITH user_logins AS (
SELECT
user_id::text AS user_id,
MIN(created_at) AS first_login_time,
MAX(created_at) AS last_login_time,
GREATEST(
MAX(refreshed_at)::timestamptz,
MAX(created_at)::timestamptz
) AS last_visit_time
FROM auth.sessions
GROUP BY user_id
),
user_agents AS (
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
SELECT
"userId"::text AS user_id,
MAX("updatedAt") AS last_agent_save_time,
COUNT(DISTINCT "id") AS agent_count
FROM platform."AgentGraph"
WHERE "isActive"
GROUP BY "userId"
),
user_graph_runs AS (
-- Aggregate AgentGraphExecution directly by userId
SELECT
"userId"::text AS user_id,
MIN("createdAt") AS first_agent_run_time,
MAX("createdAt") AS last_agent_run_time,
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
COUNT("id") AS agent_runs
FROM platform."AgentGraphExecution"
GROUP BY "userId"
),
user_node_runs AS (
-- Aggregate AgentNodeExecution directly; resolve userId via a
-- single join to AgentGraphExecution instead of fanning out from
-- user_logins through both large tables.
SELECT
g."userId"::text AS user_id,
COUNT(*) AS node_execution_count,
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
FROM platform."AgentNodeExecution" n
JOIN platform."AgentGraphExecution" g
ON g."id" = n."agentGraphExecutionId"
GROUP BY g."userId"
)
SELECT
ul.user_id,
ul.first_login_time,
ul.last_login_time,
ul.last_visit_time,
ua.last_agent_save_time,
COALESCE(ua.agent_count, 0) AS agent_count,
gr.first_agent_run_time,
gr.last_agent_run_time,
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
COALESCE(gr.agent_runs, 0) AS agent_runs,
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
CASE
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
ELSE NULL
END AS is_active_after_7d,
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id

View File

@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
## ===== SIGNUP / INVITE GATE ===== ##
# Set to true to require an invite before users can sign up
ENABLE_INVITE_GATE=false
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000

View File

@@ -58,31 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Code Style
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
- **Pydantic models** over dataclass/namedtuple/dict for structured data
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
- **List comprehensions** over manual loop-and-append
- **Early return** — guard clauses first, avoid deep nesting
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
- **`max(0, value)` guards** — for computed values that should never be negative
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
- After refactoring, update mock targets to match new module paths
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
## Database Schema
@@ -178,16 +157,6 @@ yield "image_url", result_url
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
@@ -9,9 +7,6 @@ from pydantic import BaseModel, EmailStr
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
@@ -43,14 +38,9 @@ class InvitedUserResponse(BaseModel):
created_at: datetime
updated_at: datetime
@classmethod
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
return cls.model_validate(record.model_dump())
class InvitedUsersResponse(BaseModel):
invited_users: list[InvitedUserResponse]
pagination: Pagination
class BulkInvitedUserRowResponse(BaseModel):
@@ -67,26 +57,3 @@ class BulkInvitedUsersResponse(BaseModel):
skipped_count: int
error_count: int
results: list[BulkInvitedUserRowResponse]
@classmethod
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return cls(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
InvitedUserResponse.from_record(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)

View File

@@ -1,20 +1,20 @@
import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, Query, Security, UploadFile
from fastapi import APIRouter, File, Security, UploadFile
from backend.data.invited_user import (
BulkInvitedUsersResult,
InvitedUserRecord,
bulk_create_invited_users_from_file,
create_invited_user,
list_invited_users,
retry_invited_user_tally,
revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.util.models import Pagination
from .model import (
BulkInvitedUserRowResponse,
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
@@ -31,6 +31,33 @@ router = APIRouter(
)
def _to_response(invited_user: InvitedUserRecord) -> InvitedUserResponse:
return InvitedUserResponse(**invited_user.model_dump())
def _to_bulk_response(result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
return BulkInvitedUsersResponse(
created_count=result.created_count,
skipped_count=result.skipped_count,
error_count=result.error_count,
results=[
BulkInvitedUserRowResponse(
row_number=row.row_number,
email=row.email,
name=row.name,
status=row.status,
message=row.message,
invited_user=(
_to_response(row.invited_user)
if row.invited_user is not None
else None
),
)
for row in result.results
],
)
@router.get(
"/invited-users",
response_model=InvitedUsersResponse,
@@ -38,19 +65,11 @@ router = APIRouter(
)
async def get_invited_users(
admin_user_id: str = Security(get_user_id),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
) -> InvitedUsersResponse:
logger.info("Admin user %s requested invited users", admin_user_id)
invited_users, total = await list_invited_users(page=page, page_size=page_size)
invited_users = await list_invited_users()
return InvitedUsersResponse(
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
pagination=Pagination(
total_items=total,
total_pages=max(1, math.ceil(total / page_size)),
current_page=page,
page_size=page_size,
),
invited_users=[_to_response(invited_user) for invited_user in invited_users]
)
@@ -64,17 +83,12 @@ async def create_invited_user_route(
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s creating invited user for %s",
"Admin user %s created invited user for %s",
admin_user_id,
mask_email(request.email),
request.email,
)
invited_user = await create_invited_user(request.email, request.name)
logger.info(
"Admin user %s created invited user %s",
admin_user_id,
invited_user.id,
)
return InvitedUserResponse.from_record(invited_user)
return _to_response(invited_user)
@router.post(
@@ -94,7 +108,7 @@ async def bulk_create_invited_users_route(
)
content = await file.read()
result = await bulk_create_invited_users_from_file(file.filename, content)
return BulkInvitedUsersResponse.from_result(result)
return _to_bulk_response(result)
@router.post(
@@ -106,12 +120,9 @@ async def revoke_invited_user_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
)
invited_user = await revoke_invited_user(invited_user_id)
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
return InvitedUserResponse.from_record(invited_user)
invited_user = await revoke_invited_user(invited_user_id)
return _to_response(invited_user)
@router.post(
@@ -123,15 +134,10 @@ async def retry_invited_user_tally_route(
invited_user_id: str,
admin_user_id: str = Security(get_user_id),
) -> InvitedUserResponse:
logger.info(
"Admin user %s retrying Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
invited_user = await retry_invited_user_tally(invited_user_id)
logger.info(
"Admin user %s retried Tally seed for invited user %s",
admin_user_id,
invited_user_id,
)
return InvitedUserResponse.from_record(invited_user)
invited_user = await retry_invited_user_tally(invited_user_id)
return _to_response(invited_user)

View File

@@ -77,7 +77,7 @@ def test_get_invited_users(
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.list_invited_users",
AsyncMock(return_value=([_sample_invited_user()], 1)),
AsyncMock(return_value=[_sample_invited_user()]),
)
response = client.get("/admin/invited-users")
@@ -87,9 +87,6 @@ def test_get_invited_users(
assert len(data["invited_users"]) == 1
assert data["invited_users"][0]["email"] == "invited@example.com"
assert data["invited_users"][0]["status"] == "INVITED"
assert data["pagination"]["total_items"] == 1
assert data["pagination"]["current_page"] == 1
assert data["pagination"]["page_size"] == 50
def test_create_invited_user(

View File

@@ -28,7 +28,6 @@ from backend.copilot.model import (
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -53,8 +52,6 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
@@ -129,7 +126,6 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
is_processing: bool
class ListSessionsResponse(BaseModel):
@@ -188,28 +184,6 @@ async def list_sessions(
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
if sessions:
try:
redis = await get_redis_async()
pipe = redis.pipeline(transaction=False)
for session in sessions:
pipe.hget(
f"{config.session_meta_prefix}{session.session_id}",
"status",
)
statuses = await pipe.execute()
processing_set = {
session.session_id
for session, st in zip(sessions, statuses)
if st == "running"
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; " "defaulting to empty"
)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
@@ -217,7 +191,6 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
is_processing=session.session_id in processing_set,
)
for session in sessions
],
@@ -292,12 +265,12 @@ async def delete_session(
)
# Best-effort cleanup of the E2B sandbox (if any).
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
try:
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
await kill_sandbox(session_id, config.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
@@ -854,36 +827,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Suggested Prompts ==========
class SuggestedPromptsResponse(BaseModel):
"""Response model for user-specific suggested prompts."""
prompts: list[str]
@router.get(
"/suggested-prompts",
dependencies=[Security(auth.requires_user)],
)
async def get_suggested_prompts(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> SuggestedPromptsResponse:
"""
Get LLM-generated suggested prompts for the authenticated user.
Returns personalized quick-action prompts based on the user's
business understanding. Returns an empty list if no custom prompts
are available.
"""
understanding = await get_business_understanding(user_id)
if understanding is None:
return SuggestedPromptsResponse(prompts=[])
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
# ========== Configuration ==========

View File

@@ -1,6 +1,6 @@
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -249,62 +249,3 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Suggested prompts endpoint ──────────────────────────────────────
def _mock_get_business_understanding(
mocker: pytest_mock.MockerFixture,
*,
return_value=None,
):
"""Mock get_business_understanding."""
return mocker.patch(
"backend.api.features.chat.routes.get_business_understanding",
new_callable=AsyncMock,
return_value=return_value,
)
def test_suggested_prompts_returns_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding and prompts gets them back."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
def test_suggested_prompts_no_understanding(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with no understanding gets empty list."""
_mock_get_business_understanding(mocker, return_value=None)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}
def test_suggested_prompts_empty_prompts(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""User with understanding but no prompts gets empty list."""
mock_understanding = MagicMock()
mock_understanding.suggested_prompts = []
_mock_get_business_understanding(mocker, return_value=mock_understanding)
response = client.get("/suggested-prompts")
assert response.status_code == 200
assert response.json() == {"prompts": []}

View File

@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
# Mock get_node_executions to return node_id mapping
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
# Mock get_node_executions to return batch node data
mock_get_node_executions = mocker.patch(
"backend.api.features.executions.review.routes.get_node_executions"
"backend.data.execution.get_node_executions"
)
# Create mock node executions for each review
mock_node_execs = []

View File

@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, HTTPException, Query, Security, status
from prisma.enums import ReviewStatus
from backend.copilot.constants import (
is_copilot_synthetic_id,
parse_node_id_from_exec_id,
)
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
get_graph_execution_meta,
get_node_executions,
)
from backend.data.graph import get_graph_settings
from backend.data.human_review import (
@@ -41,38 +36,6 @@ router = APIRouter(
)
async def _resolve_node_ids(
node_exec_ids: list[str],
graph_exec_id: str,
is_copilot: bool,
) -> dict[str, str]:
"""Resolve node_exec_id -> node_id for auto-approval records.
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
Graph executions look up node_id from NodeExecution records.
"""
if not node_exec_ids:
return {}
if is_copilot:
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
result = {}
for neid in node_exec_ids:
if neid in node_exec_map:
result[neid] = node_exec_map[neid]
else:
logger.error(
f"Failed to resolve node_id for {neid}: Node execution not found."
)
return result
@router.get(
"/pending",
summary="Get Pending Reviews",
@@ -147,16 +110,14 @@ async def list_pending_reviews_for_execution(
"""
# Verify user owns the graph execution before returning reviews
# (CoPilot synthetic IDs don't have graph execution records)
if not is_copilot_synthetic_id(graph_exec_id):
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
graph_exec = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if not graph_exec:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
@@ -199,26 +160,30 @@ async def process_review_action(
)
graph_exec_id = next(iter(graph_exec_ids))
is_copilot = is_copilot_synthetic_id(graph_exec_id)
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
if not is_copilot:
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
# Validate execution status before processing reviews
graph_exec_meta = await get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
# Only allow processing reviews if execution is paused for review
# or incomplete (partial execution with some reviews already processed)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
f"Reviews can only be processed when execution is paused (REVIEW status). "
f"Current status: {graph_exec_meta.status}",
)
if not graph_exec_meta:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Graph execution #{graph_exec_id} not found",
)
if graph_exec_meta.status not in (
ExecutionStatus.REVIEW,
ExecutionStatus.INCOMPLETE,
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
)
# Build review decisions map and track which reviews requested auto-approval
# Auto-approved reviews use original data (no modifications allowed)
@@ -271,7 +236,7 @@ async def process_review_action(
)
return (node_id, False)
# Collect node_exec_ids that need auto-approval and resolve their node_ids
# Collect node_exec_ids that need auto-approval
node_exec_ids_needing_auto_approval = [
node_exec_id
for node_exec_id, review_result in updated_reviews.items()
@@ -279,16 +244,29 @@ async def process_review_action(
and auto_approve_requests.get(node_exec_id, False)
]
node_id_map = await _resolve_node_ids(
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
)
# Deduplicate by node_id — one auto-approval per node
# Batch-fetch node executions to get node_ids
nodes_needing_auto_approval: dict[str, Any] = {}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_id = node_id_map.get(node_exec_id)
if node_id and node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
if node_exec_ids_needing_auto_approval:
from backend.data.execution import get_node_executions
node_execs = await get_node_executions(
graph_exec_id=graph_exec_id, include_exec_data=False
)
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
for node_exec_id in node_exec_ids_needing_auto_approval:
node_exec = node_exec_map.get(node_exec_id)
if node_exec:
review_result = updated_reviews[node_exec_id]
# Use the first approved review for this node (deduplicate by node_id)
if node_exec.node_id not in nodes_needing_auto_approval:
nodes_needing_auto_approval[node_exec.node_id] = review_result
else:
logger.error(
f"Failed to create auto-approval record for {node_exec_id}: "
f"Node execution not found. This may indicate a race condition "
f"or data inconsistency."
)
# Execute all auto-approval creations in parallel (deduplicated by node_id)
auto_approval_results = await asyncio.gather(
@@ -303,11 +281,13 @@ async def process_review_action(
auto_approval_failed_count = 0
for result in auto_approval_results:
if isinstance(result, Exception):
# Unexpected exception during auto-approval creation
auto_approval_failed_count += 1
logger.error(
f"Unexpected exception during auto-approval creation: {result}"
)
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
# Auto-approval creation failed (returned False)
auto_approval_failed_count += 1
# Count results
@@ -322,20 +302,22 @@ async def process_review_action(
if review.status == ReviewStatus.REJECTED
)
# Resume graph execution only for real graph executions (not CoPilot)
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
if not is_copilot and updated_reviews:
# Resume execution only if ALL pending reviews for this execution have been processed
if updated_reviews:
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
if not still_has_pending:
# Get the graph_id from any processed review
first_review = next(iter(updated_reviews.values()))
try:
# Fetch user and settings to build complete execution context
user = await get_user_by_id(user_id)
settings = await get_graph_settings(
user_id=user_id, graph_id=first_review.graph_id
)
# Preserve user's timezone preference when resuming execution
user_timezone = (
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)

View File

@@ -165,6 +165,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
image_url: str | None
@@ -205,9 +206,7 @@ class LibraryAgent(pydantic.BaseModel):
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool = pydantic.Field(
description="Indicates whether the same user owns the corresponding graph"
)
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
@@ -325,6 +324,7 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -129,6 +131,7 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -181,6 +184,7 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url_host
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -80,7 +80,7 @@ async def discover_tools(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url_host(auth_server_url)
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url_host(registration_endpoint)
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
@@ -429,7 +429,7 @@ async def mcp_store_token(
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url_host(request.server_url)
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

View File

@@ -32,9 +32,9 @@ async def client():
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
@@ -521,12 +521,12 @@ class TestStoreToken:
class TestSSRFValidation:
"""Verify that validate_url_host is enforced on all endpoints."""
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
@@ -541,7 +541,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
@@ -556,7 +556,7 @@ class TestSSRFValidation:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url_host",
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):

View File

@@ -165,7 +165,9 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
async def update_user_email_route(
user_id: Annotated[str, Security(get_user_id)],
email: str = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> dict[str, str]:
await get_or_activate_user(user_data)
await update_user_email(user_id, email)
return {"email": email}
@@ -178,16 +180,10 @@ async def update_user_email_route(
dependencies=[Security(requires_user)],
)
async def get_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Get user timezone setting."""
try:
user = await get_user_by_id(user_id)
except ValueError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail="User not found. Please complete activation via /auth/user first.",
)
user = await get_or_activate_user(user_data)
return TimezoneResponse(timezone=user.timezone)
@@ -200,8 +196,10 @@ async def get_user_timezone_route(
async def update_user_timezone_route(
user_id: Annotated[str, Security(get_user_id)],
request: UpdateTimezoneRequest,
user_data: dict = Security(get_jwt_payload),
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
await get_or_activate_user(user_data)
user = await update_user_timezone(user_id, str(request.timezone))
return TimezoneResponse(timezone=user.timezone)
@@ -214,7 +212,9 @@ async def update_user_timezone_route(
)
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -228,7 +228,9 @@ async def get_preferences(
async def update_preferences(
user_id: Annotated[str, Security(get_user_id)],
preferences: NotificationPreferenceDTO = Body(...),
user_data: dict = Security(get_jwt_payload),
) -> NotificationPreference:
await get_or_activate_user(user_data)
output = await update_user_notification_preference(user_id, preferences)
return output

View File

@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
class OnboardingNotificationPayload(NotificationPayload):
step: OnboardingStep | None
class CopilotCompletionPayload(NotificationPayload):
session_id: str
status: Literal["completed", "failed"]

View File

@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def __init__(
self,
id: str = "",
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
if self.webhook_config:
@@ -624,7 +620,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
is_graph_execution: bool = True,
**kwargs,
) -> tuple[bool, BlockInput]:
"""
@@ -653,7 +648,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
graph_version=graph_version,
block_name=self.name,
editable=True,
is_graph_execution=is_graph_execution,
)
if decision is None:

View File

@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
# Execute the code
execution = await sandbox.run_code( # type: ignore[attr-defined]
execution = await sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error

View File

@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Email sent successfully")],
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -1,3 +0,0 @@
def github_repo_path(repo_url: str) -> str:
"""Extract 'owner/repo' from a GitHub repository URL."""
return repo_url.replace("https://github.com/", "")

View File

@@ -1,408 +0,0 @@
import asyncio
from enum import StrEnum
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import parse_data_uri, resolve_media_content
from backend.util.type import MediaFileType
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListCommitsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to list commits from",
default="main",
)
per_page: int = SchemaField(
description="Number of commits to return (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class CommitItem(TypedDict):
sha: str
message: str
author: str
date: str
url: str
commit: CommitItem = SchemaField(
title="Commit", description="A commit with its details"
)
commits: list[CommitItem] = SchemaField(
description="List of commits with their details"
)
error: str = SchemaField(description="Error message if listing commits failed")
def __init__(self):
super().__init__(
id="8b13f579-d8b6-4dc2-a140-f770428805de",
description="This block lists commits on a branch in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListCommitsBlock.Input,
output_schema=GithubListCommitsBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"commits",
[
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
],
),
(
"commit",
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
},
),
],
test_mock={
"list_commits": lambda *args, **kwargs: [
{
"sha": "abc123",
"message": "Initial commit",
"author": "octocat",
"date": "2024-01-01T00:00:00Z",
"url": "https://github.com/owner/repo/commit/abc123",
}
]
},
)
@staticmethod
async def list_commits(
credentials: GithubCredentials,
repo_url: str,
branch: str,
per_page: int,
page: int,
) -> list[Output.CommitItem]:
api = get_api(credentials)
commits_url = repo_url + "/commits"
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
response = await api.get(commits_url, params=params)
data = response.json()
repo_path = github_repo_path(repo_url)
return [
GithubListCommitsBlock.Output.CommitItem(
sha=c["sha"],
message=c["commit"]["message"],
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
date=(c["commit"].get("author") or {}).get("date", ""),
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
)
for c in data
]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
commits = await self.list_commits(
credentials,
input_data.repo_url,
input_data.branch,
input_data.per_page,
input_data.page,
)
yield "commits", commits
for commit in commits:
yield "commit", commit
except Exception as e:
yield "error", str(e)
class FileOperation(StrEnum):
"""File operations for GithubMultiFileCommitBlock.
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
between creation and update — the blob is placed at the given path regardless
of whether a file already exists there).
DELETE removes a file from the tree.
"""
UPSERT = "upsert"
DELETE = "delete"
class FileOperationInput(TypedDict):
path: str
# MediaFileType is a str NewType — no runtime breakage for existing callers.
content: MediaFileType
operation: FileOperation
class GithubMultiFileCommitBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch to commit to",
placeholder="feature-branch",
)
commit_message: str = SchemaField(
description="Commit message",
placeholder="Add new feature",
)
files: list[FileOperationInput] = SchemaField(
description=(
"List of file operations. Each item has: "
"'path' (file path), 'content' (file content, ignored for delete), "
"'operation' (upsert/delete)"
),
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the new commit")
url: str = SchemaField(description="URL of the new commit")
error: str = SchemaField(description="Error message if the commit failed")
def __init__(self):
super().__init__(
id="389eee51-a95e-4230-9bed-92167a327802",
description=(
"This block creates a single commit with multiple file "
"upsert/delete operations using the Git Trees API."
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMultiFileCommitBlock.Input,
output_schema=GithubMultiFileCommitBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "Add files",
"files": [
{
"path": "src/new.py",
"content": "print('hello')",
"operation": "upsert",
},
{
"path": "src/old.py",
"content": "",
"operation": "delete",
},
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "newcommitsha"),
("url", "https://github.com/owner/repo/commit/newcommitsha"),
],
test_mock={
"multi_file_commit": lambda *args, **kwargs: (
"newcommitsha",
"https://github.com/owner/repo/commit/newcommitsha",
)
},
)
@staticmethod
async def multi_file_commit(
credentials: GithubCredentials,
repo_url: str,
branch: str,
commit_message: str,
files: list[FileOperationInput],
) -> tuple[str, str]:
api = get_api(credentials)
safe_branch = quote(branch, safe="")
# 1. Get the latest commit SHA for the branch
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
response = await api.get(ref_url)
ref_data = response.json()
latest_commit_sha = ref_data["object"]["sha"]
# 2. Get the tree SHA of the latest commit
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
response = await api.get(commit_url)
commit_data = response.json()
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": encoding},
)
return blob_response.json()["sha"]
tree_entries: list[dict] = []
upsert_files = []
for file_op in files:
path = file_op["path"]
operation = FileOperation(file_op.get("operation", "upsert"))
if operation == FileOperation.DELETE:
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": None, # null SHA = delete
}
)
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently. Data URIs (from store_media_file)
# are sent as base64 blobs to preserve binary content.
if upsert_files:
async def _make_blob(content: str) -> str:
parsed = parse_data_uri(content)
if parsed is not None:
_, b64_payload = parsed
return await _create_blob(b64_payload, encoding="base64")
return await _create_blob(content)
blob_shas = await asyncio.gather(
*[_make_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
{
"path": path,
"mode": "100644",
"type": "blob",
"sha": blob_sha,
}
)
# 4. Create a new tree
tree_url = repo_url + "/git/trees"
tree_response = await api.post(
tree_url,
json={"base_tree": base_tree_sha, "tree": tree_entries},
)
new_tree_sha = tree_response.json()["sha"]
# 5. Create a new commit
new_commit_url = repo_url + "/git/commits"
commit_response = await api.post(
new_commit_url,
json={
"message": commit_message,
"tree": new_tree_sha,
"parents": [latest_commit_sha],
},
)
new_commit_sha = commit_response.json()["sha"]
# 6. Update the branch reference
try:
await api.patch(
ref_url,
json={"sha": new_commit_sha},
)
except Exception as e:
raise RuntimeError(
f"Commit {new_commit_sha} was created but failed to update "
f"ref heads/{branch}: {e}. "
f"You can recover by manually updating the branch to {new_commit_sha}."
) from e
repo_path = github_repo_path(repo_url)
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
return new_commit_sha, commit_web_url
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
# Resolve media references (workspace://, data:, URLs) to data
# URIs so _make_blob can send binary content correctly.
resolved_files: list[FileOperationInput] = []
for file_op in input_data.files:
content = file_op.get("content", "")
operation = FileOperation(file_op.get("operation", "upsert"))
if operation != FileOperation.DELETE:
content = await resolve_media_content(
MediaFileType(content),
execution_context,
return_format="for_external_api",
)
resolved_files.append(
FileOperationInput(
path=file_op["path"],
content=MediaFileType(content),
operation=operation,
)
)
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
resolved_files,
)
yield "sha", sha
yield "url", url
except Exception as e:
yield "error", str(e)

View File

@@ -1,5 +1,4 @@
import re
from typing import Literal
from typing_extensions import TypedDict
@@ -21,8 +20,6 @@ from ._auth import (
GithubCredentialsInput,
)
MergeMethod = Literal["merge", "squash", "rebase"]
class GithubListPullRequestsBlock(Block):
class Input(BlockSchemaInput):
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
yield "reviewer", reviewer
class GithubMergePullRequestBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
pr_url: str = SchemaField(
description="URL of the GitHub pull request",
placeholder="https://github.com/owner/repo/pull/1",
)
merge_method: MergeMethod = SchemaField(
description="Merge method to use: merge, squash, or rebase",
default="merge",
)
commit_title: str = SchemaField(
description="Title for the merge commit (optional, used for merge and squash)",
default="",
)
commit_message: str = SchemaField(
description="Message for the merge commit (optional, used for merge and squash)",
default="",
)
class Output(BlockSchemaOutput):
sha: str = SchemaField(description="SHA of the merge commit")
merged: bool = SchemaField(description="Whether the PR was merged")
message: str = SchemaField(description="Merge status message")
error: str = SchemaField(description="Error message if the merge failed")
def __init__(self):
super().__init__(
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
description="This block merges a pull request using merge, squash, or rebase.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMergePullRequestBlock.Input,
output_schema=GithubMergePullRequestBlock.Output,
test_input={
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("sha", "abc123"),
("merged", True),
("message", "Pull Request successfully merged"),
],
test_mock={
"merge_pr": lambda *args, **kwargs: (
"abc123",
True,
"Pull Request successfully merged",
)
},
is_sensitive_action=True,
)
@staticmethod
async def merge_pr(
credentials: GithubCredentials,
pr_url: str,
merge_method: MergeMethod,
commit_title: str,
commit_message: str,
) -> tuple[str, bool, str]:
api = get_api(credentials)
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
data: dict[str, str] = {"merge_method": merge_method}
if commit_title:
data["commit_title"] = commit_title
if commit_message:
data["commit_message"] = commit_message
response = await api.put(merge_url, json=data)
result = response.json()
return result["sha"], result["merged"], result["message"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
sha, merged, message = await self.merge_pr(
credentials,
input_data.pr_url,
input_data.merge_method,
input_data.commit_title,
input_data.commit_message,
)
yield "sha", sha
yield "merged", merged
yield "message", message
except Exception as e:
yield "error", str(e)
def prepare_pr_api_url(pr_url: str, path: str) -> str:
# Pattern to capture the base repository URL and the pull request number
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
match = re.match(pattern, pr_url)
if not match:
return pr_url
scheme, base_url, pr_number = match.groups()
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
base_url, pr_number = match.groups()
return f"{base_url}/pulls/{pr_number}/{path}"

View File

@@ -1,3 +1,5 @@
import base64
from typing_extensions import TypedDict
from backend.blocks._base import (
@@ -17,7 +19,6 @@ from ._auth import (
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListTagsBlock(Block):
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
tags_url = repo_url + "/tags"
response = await api.get(tags_url)
data = response.json()
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
tags: list[GithubListTagsBlock.Output.TagItem] = [
{
"name": tag["name"],
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
yield "tag", tag
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(branches_url)
data = response.json()
repo_path = repo_url.replace("https://github.com/", "")
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
branches = await self.list_branches(
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
class GithubListDiscussionsBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
) -> list[Output.DiscussionItem]:
api = get_api(credentials)
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
repo_path = github_repo_path(repo_url)
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
query = """
query($owner: String!, $repo: String!, $num: Int!) {
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
yield "release", release
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to master)",
placeholder="branch_name",
default="master",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "master",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{branch}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{file_path}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubCreateRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
def __init__(self):
super().__init__(
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
description="This block lists all users who have starred a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListStargazersBlock.Input,
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer
class GithubGetRepositoryInfoBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
name: str = SchemaField(description="Repository name")
full_name: str = SchemaField(description="Full repository name (owner/repo)")
description: str = SchemaField(description="Repository description")
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
private: bool = SchemaField(description="Whether the repository is private")
html_url: str = SchemaField(description="Web URL of the repository")
clone_url: str = SchemaField(description="Git clone URL")
stars: int = SchemaField(description="Number of stars")
forks: int = SchemaField(description="Number of forks")
open_issues: int = SchemaField(description="Number of open issues")
error: str = SchemaField(
description="Error message if fetching repo info failed"
)
def __init__(self):
super().__init__(
id="59d4f241-968a-4040-95da-348ac5c5ce27",
description="This block retrieves metadata about a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryInfoBlock.Input,
output_schema=GithubGetRepositoryInfoBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("name", "repo"),
("full_name", "owner/repo"),
("description", "A test repo"),
("default_branch", "main"),
("private", False),
("html_url", "https://github.com/owner/repo"),
("clone_url", "https://github.com/owner/repo.git"),
("stars", 42),
("forks", 5),
("open_issues", 3),
],
test_mock={
"get_repo_info": lambda *args, **kwargs: {
"name": "repo",
"full_name": "owner/repo",
"description": "A test repo",
"default_branch": "main",
"private": False,
"html_url": "https://github.com/owner/repo",
"clone_url": "https://github.com/owner/repo.git",
"stargazers_count": 42,
"forks_count": 5,
"open_issues_count": 3,
}
},
)
@staticmethod
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
api = get_api(credentials)
response = await api.get(repo_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.get_repo_info(credentials, input_data.repo_url)
yield "name", data["name"]
yield "full_name", data["full_name"]
yield "description", data.get("description", "") or ""
yield "default_branch", data["default_branch"]
yield "private", data["private"]
yield "html_url", data["html_url"]
yield "clone_url", data["clone_url"]
yield "stars", data["stargazers_count"]
yield "forks", data["forks_count"]
yield "open_issues", data["open_issues_count"]
except Exception as e:
yield "error", str(e)
class GithubForkRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to fork",
placeholder="https://github.com/owner/repo",
)
organization: str = SchemaField(
description="Organization to fork into (leave empty to fork to your account)",
default="",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the forked repository")
clone_url: str = SchemaField(description="Git clone URL of the fork")
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
error: str = SchemaField(description="Error message if the fork failed")
def __init__(self):
super().__init__(
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
description="This block forks a GitHub repository to your account or an organization.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubForkRepositoryBlock.Input,
output_schema=GithubForkRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"organization": "",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/myuser/repo"),
("clone_url", "https://github.com/myuser/repo.git"),
("full_name", "myuser/repo"),
],
test_mock={
"fork_repo": lambda *args, **kwargs: (
"https://github.com/myuser/repo",
"https://github.com/myuser/repo.git",
"myuser/repo",
)
},
)
@staticmethod
async def fork_repo(
credentials: GithubCredentials,
repo_url: str,
organization: str,
) -> tuple[str, str, str]:
api = get_api(credentials)
forks_url = repo_url + "/forks"
data: dict[str, str] = {}
if organization:
data["organization"] = organization
response = await api.post(forks_url, json=data)
result = response.json()
return result["html_url"], result["clone_url"], result["full_name"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, clone_url, full_name = await self.fork_repo(
credentials,
input_data.repo_url,
input_data.organization,
)
yield "url", url
yield "clone_url", clone_url
yield "full_name", full_name
except Exception as e:
yield "error", str(e)
class GithubStarRepositoryBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository to star",
placeholder="https://github.com/owner/repo",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the star operation")
error: str = SchemaField(description="Error message if starring failed")
def __init__(self):
super().__init__(
id="bd700764-53e3-44dd-a969-d1854088458f",
description="This block stars a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubStarRepositoryBlock.Input,
output_schema=GithubStarRepositoryBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Repository starred successfully")],
test_mock={
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
},
)
@staticmethod
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
api = get_api(credentials, convert_urls=False)
repo_path = github_repo_path(repo_url)
owner, repo = repo_path.split("/")
await api.put(
f"https://api.github.com/user/starred/{owner}/{repo}",
headers={"Content-Length": "0"},
)
return "Repository starred successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.star_repo(credentials, input_data.repo_url)
yield "status", status
except Exception as e:
yield "error", str(e)

View File

@@ -1,452 +0,0 @@
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
from ._utils import github_repo_path
class GithubListBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
per_page: int = SchemaField(
description="Number of branches to return per page (max 100)",
default=30,
ge=1,
le=100,
)
page: int = SchemaField(
description="Page number for pagination",
default=1,
ge=1,
)
class Output(BlockSchemaOutput):
class BranchItem(TypedDict):
name: str
url: str
branch: BranchItem = SchemaField(
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
super().__init__(
id="74243e49-2bec-4916-8bf4-db43d44aead5",
description="This block lists all branches for a specified GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListBranchesBlock.Input,
output_schema=GithubListBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"per_page": 30,
"page": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
],
test_mock={
"list_branches": lambda *args, **kwargs: [
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
]
},
)
@staticmethod
async def list_branches(
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
) -> list[Output.BranchItem]:
api = get_api(credentials)
branches_url = repo_url + "/branches"
response = await api.get(
branches_url, params={"per_page": str(per_page), "page": str(page)}
)
data = response.json()
repo_path = github_repo_path(repo_url)
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
{
"name": branch["name"],
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
}
for branch in data
]
return branches
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = await self.list_branches(
credentials,
input_data.repo_url,
input_data.per_page,
input_data.page,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
except Exception as e:
yield "error", str(e)
class GithubMakeBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
new_branch: str = SchemaField(
description="Name of the new branch",
placeholder="new_branch_name",
)
source_branch: str = SchemaField(
description="Name of the source branch",
placeholder="source_branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch creation operation")
error: str = SchemaField(
description="Error message if the branch creation failed"
)
def __init__(self):
super().__init__(
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
description="This block creates a new branch from a specified source branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubMakeBranchBlock.Input,
output_schema=GithubMakeBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"new_branch": "new_branch_name",
"source_branch": "source_branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch created successfully")],
test_mock={
"create_branch": lambda *args, **kwargs: "Branch created successfully"
},
)
@staticmethod
async def create_branch(
credentials: GithubCredentials,
repo_url: str,
new_branch: str,
source_branch: str,
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
response = await api.get(ref_url)
data = response.json()
sha = data["object"]["sha"]
# Create the new branch
new_ref_url = repo_url + "/git/refs"
data = {
"ref": f"refs/heads/{new_branch}",
"sha": sha,
}
response = await api.post(new_ref_url, json=data)
return "Branch created successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubDeleteBranchBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Name of the branch to delete",
placeholder="branch_name",
)
class Output(BlockSchemaOutput):
status: str = SchemaField(description="Status of the branch deletion operation")
error: str = SchemaField(
description="Error message if the branch deletion failed"
)
def __init__(self):
super().__init__(
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
description="This block deletes a specified branch.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubDeleteBranchBlock.Input,
output_schema=GithubDeleteBranchBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "branch_name",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("status", "Branch deleted successfully")],
test_mock={
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
},
is_sensitive_action=True,
)
@staticmethod
async def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
api = get_api(credentials)
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
await api.delete(ref_url)
return "Branch deleted successfully"
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = await self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", str(e)
class GithubCompareBranchesBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
base: str = SchemaField(
description="Base branch or commit SHA",
placeholder="main",
)
head: str = SchemaField(
description="Head branch or commit SHA to compare against base",
placeholder="feature-branch",
)
class Output(BlockSchemaOutput):
class FileChange(TypedDict):
filename: str
status: str
additions: int
deletions: int
patch: str
status: str = SchemaField(
description="Comparison status: ahead, behind, diverged, or identical"
)
ahead_by: int = SchemaField(
description="Number of commits head is ahead of base"
)
behind_by: int = SchemaField(
description="Number of commits head is behind base"
)
total_commits: int = SchemaField(
description="Total number of commits in the comparison"
)
diff: str = SchemaField(description="Unified diff of all file changes")
file: FileChange = SchemaField(
title="Changed File", description="A changed file with its diff"
)
files: list[FileChange] = SchemaField(
description="List of changed files with their diffs"
)
error: str = SchemaField(description="Error message if comparison failed")
def __init__(self):
super().__init__(
id="2e4faa8c-6086-4546-ba77-172d1d560186",
description="This block compares two branches or commits in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCompareBranchesBlock.Input,
output_schema=GithubCompareBranchesBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"base": "main",
"head": "feature",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("status", "ahead"),
("ahead_by", 2),
("behind_by", 0),
("total_commits", 2),
("diff", "+++ b/file.py\n+new line"),
(
"files",
[
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
),
(
"file",
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
},
),
],
test_mock={
"compare_branches": lambda *args, **kwargs: {
"status": "ahead",
"ahead_by": 2,
"behind_by": 0,
"total_commits": 2,
"files": [
{
"filename": "file.py",
"status": "modified",
"additions": 1,
"deletions": 0,
"patch": "+new line",
}
],
}
},
)
@staticmethod
async def compare_branches(
credentials: GithubCredentials,
repo_url: str,
base: str,
head: str,
) -> dict:
api = get_api(credentials)
safe_base = quote(base, safe="")
safe_head = quote(head, safe="")
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
response = await api.get(compare_url)
return response.json()
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
data = await self.compare_branches(
credentials,
input_data.repo_url,
input_data.base,
input_data.head,
)
yield "status", data["status"]
yield "ahead_by", data["ahead_by"]
yield "behind_by", data["behind_by"]
yield "total_commits", data["total_commits"]
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
GithubCompareBranchesBlock.Output.FileChange(
filename=f["filename"],
status=f["status"],
additions=f["additions"],
deletions=f["deletions"],
patch=f.get("patch", ""),
)
for f in data.get("files", [])
]
# Build unified diff
diff_parts = []
for f in data.get("files", []):
patch = f.get("patch", "")
if patch:
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
yield "diff", "\n".join(diff_parts)
yield "files", files
for file in files:
yield "file", file
except Exception as e:
yield "error", str(e)

View File

@@ -1,720 +0,0 @@
import base64
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
class GithubReadFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file in the repository",
placeholder="path/to/file",
)
branch: str = SchemaField(
description="Branch to read from",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
text_content: str = SchemaField(
description="Content of the file (decoded as UTF-8 text)"
)
raw_content: str = SchemaField(
description="Raw base64-encoded content of the file"
)
size: int = SchemaField(description="The size of the file (in bytes)")
error: str = SchemaField(description="Error message if reading the file failed")
def __init__(self):
super().__init__(
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
description="This block reads the content of a specified file from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFileBlock.Input,
output_schema=GithubReadFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "path/to/file",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("raw_content", "RmlsZSBjb250ZW50"),
("text_content", "File content"),
("size", 13),
],
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
)
@staticmethod
async def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
api = get_api(credentials)
content_url = (
repo_url
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
)
response = await api.get(content_url)
data = response.json()
if isinstance(data, list):
# Multiple entries of different types exist at this path
if not (file := next((f for f in data if f["type"] == "file"), None)):
raise TypeError("Not a file")
data = file
if data["type"] != "file":
raise TypeError("Not a file")
return data["content"], data["size"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
content, size = await self.read_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.branch,
)
yield "raw_content", content
yield "text_content", base64.b64decode(content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", str(e)
class GithubReadFolderBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
folder_path: str = SchemaField(
description="Path to the folder in the repository",
placeholder="path/to/folder",
)
branch: str = SchemaField(
description="Branch name to read from (defaults to main)",
placeholder="branch_name",
default="main",
)
class Output(BlockSchemaOutput):
class DirEntry(TypedDict):
name: str
path: str
class FileEntry(TypedDict):
name: str
path: str
size: int
file: FileEntry = SchemaField(description="Files in the folder")
dir: DirEntry = SchemaField(description="Directories in the folder")
error: str = SchemaField(
description="Error message if reading the folder failed"
)
def __init__(self):
super().__init__(
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
description="This block reads the content of a specified folder from a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubReadFolderBlock.Input,
output_schema=GithubReadFolderBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"folder_path": "path/to/folder",
"branch": "main",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"file",
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
},
),
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
],
test_mock={
"read_folder": lambda *args, **kwargs: (
[
{
"name": "file1.txt",
"path": "path/to/folder/file1.txt",
"size": 1337,
}
],
[{"name": "dir2", "path": "path/to/folder/dir2"}],
)
},
)
@staticmethod
async def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
api = get_api(credentials)
contents_url = (
repo_url
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
)
response = await api.get(contents_url)
data = response.json()
if not isinstance(data, list):
raise TypeError("Not a folder")
files: list[GithubReadFolderBlock.Output.FileEntry] = [
GithubReadFolderBlock.Output.FileEntry(
name=entry["name"],
path=entry["path"],
size=entry["size"],
)
for entry in data
if entry["type"] == "file"
]
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
GithubReadFolderBlock.Output.DirEntry(
name=entry["name"],
path=entry["path"],
)
for entry in data
if entry["type"] == "dir"
]
return files, dirs
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = await self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
for file in files:
yield "file", file
for dir in dirs:
yield "dir", dir
except Exception as e:
yield "error", str(e)
class GithubCreateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path where the file should be created",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="Content to write to the file",
placeholder="File content here",
)
branch: str = SchemaField(
description="Branch where the file should be created",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Create new file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the created file")
sha: str = SchemaField(description="SHA of the commit")
error: str = SchemaField(
description="Error message if the file creation failed"
)
def __init__(self):
super().__init__(
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
description="This block creates a new file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateFileBlock.Input,
output_schema=GithubCreateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Test content",
"branch": "main",
"commit_message": "Create test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "abc123"),
],
test_mock={
"create_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"abc123",
)
},
)
@staticmethod
async def create_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.create_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubUpdateFileBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
file_path: str = SchemaField(
description="Path to the file to update",
placeholder="path/to/file.txt",
)
content: str = SchemaField(
description="New content for the file",
placeholder="Updated content here",
)
branch: str = SchemaField(
description="Branch containing the file",
default="main",
)
commit_message: str = SchemaField(
description="Message for the commit",
default="Update file",
)
class Output(BlockSchemaOutput):
url: str = SchemaField(description="URL of the updated file")
sha: str = SchemaField(description="SHA of the commit")
def __init__(self):
super().__init__(
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
description="This block updates an existing file in a GitHub repository.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubUpdateFileBlock.Input,
output_schema=GithubUpdateFileBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"file_path": "test/file.txt",
"content": "Updated content",
"branch": "main",
"commit_message": "Update test file",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
("sha", "def456"),
],
test_mock={
"update_file": lambda *args, **kwargs: (
"https://github.com/owner/repo/blob/main/test/file.txt",
"def456",
)
},
)
@staticmethod
async def update_file(
credentials: GithubCredentials,
repo_url: str,
file_path: str,
content: str,
branch: str,
commit_message: str,
) -> tuple[str, str]:
api = get_api(credentials)
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
params = {"ref": branch}
response = await api.get(contents_url, params=params)
data = response.json()
# Convert new content to base64
content_base64 = base64.b64encode(content.encode()).decode()
data = {
"message": commit_message,
"content": content_base64,
"sha": data["sha"],
"branch": branch,
}
response = await api.put(contents_url, json=data)
data = response.json()
return data["content"]["html_url"], data["commit"]["sha"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
url, sha = await self.update_file(
credentials,
input_data.repo_url,
input_data.file_path,
input_data.content,
input_data.branch,
input_data.commit_message,
)
yield "url", url
yield "sha", sha
except Exception as e:
yield "error", str(e)
class GithubSearchCodeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
query: str = SchemaField(
description="Search query (GitHub code search syntax)",
placeholder="className language:python",
)
repo: str = SchemaField(
description="Restrict search to a repository (owner/repo format, optional)",
default="",
placeholder="owner/repo",
)
per_page: int = SchemaField(
description="Number of results to return (max 100)",
default=30,
ge=1,
le=100,
)
class Output(BlockSchemaOutput):
class SearchResult(TypedDict):
name: str
path: str
repository: str
url: str
score: float
result: SearchResult = SchemaField(
title="Result", description="A code search result"
)
results: list[SearchResult] = SchemaField(
description="List of code search results"
)
total_count: int = SchemaField(description="Total number of matching results")
error: str = SchemaField(description="Error message if search failed")
def __init__(self):
super().__init__(
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
description="This block searches for code in GitHub repositories.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSearchCodeBlock.Input,
output_schema=GithubSearchCodeBlock.Output,
test_input={
"query": "addClass",
"repo": "owner/repo",
"per_page": 30,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("total_count", 1),
(
"results",
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
),
(
"result",
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
},
),
],
test_mock={
"search_code": lambda *args, **kwargs: (
1,
[
{
"name": "file.py",
"path": "src/file.py",
"repository": "owner/repo",
"url": "https://github.com/owner/repo/blob/main/src/file.py",
"score": 1.0,
}
],
)
},
)
@staticmethod
async def search_code(
credentials: GithubCredentials,
query: str,
repo: str,
per_page: int,
) -> tuple[int, list[Output.SearchResult]]:
api = get_api(credentials, convert_urls=False)
full_query = f"{query} repo:{repo}" if repo else query
params = {"q": full_query, "per_page": str(per_page)}
response = await api.get("https://api.github.com/search/code", params=params)
data = response.json()
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
GithubSearchCodeBlock.Output.SearchResult(
name=item["name"],
path=item["path"],
repository=item["repository"]["full_name"],
url=item["html_url"],
score=item["score"],
)
for item in data["items"]
]
return data["total_count"], results
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
total_count, results = await self.search_code(
credentials,
input_data.query,
input_data.repo,
input_data.per_page,
)
yield "total_count", total_count
yield "results", results
for result in results:
yield "result", result
except Exception as e:
yield "error", str(e)
class GithubGetRepositoryTreeBlock(Block):
class Input(BlockSchemaInput):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo_url: str = SchemaField(
description="URL of the GitHub repository",
placeholder="https://github.com/owner/repo",
)
branch: str = SchemaField(
description="Branch name to get the tree from",
default="main",
)
recursive: bool = SchemaField(
description="Whether to recursively list the entire tree",
default=True,
)
class Output(BlockSchemaOutput):
class TreeEntry(TypedDict):
path: str
type: str
size: int
sha: str
entry: TreeEntry = SchemaField(
title="Tree Entry", description="A file or directory in the tree"
)
entries: list[TreeEntry] = SchemaField(
description="List of all files and directories in the tree"
)
truncated: bool = SchemaField(
description="Whether the tree was truncated due to size"
)
error: str = SchemaField(description="Error message if getting tree failed")
def __init__(self):
super().__init__(
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
description="This block lists the entire file tree of a GitHub repository recursively.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetRepositoryTreeBlock.Input,
output_schema=GithubGetRepositoryTreeBlock.Output,
test_input={
"repo_url": "https://github.com/owner/repo",
"branch": "main",
"recursive": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("truncated", False),
(
"entries",
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
),
(
"entry",
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
},
),
],
test_mock={
"get_tree": lambda *args, **kwargs: (
False,
[
{
"path": "src/main.py",
"type": "blob",
"size": 1234,
"sha": "abc123",
}
],
)
},
)
@staticmethod
async def get_tree(
credentials: GithubCredentials,
repo_url: str,
branch: str,
recursive: bool,
) -> tuple[bool, list[Output.TreeEntry]]:
api = get_api(credentials)
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
params = {"recursive": "1"} if recursive else {}
response = await api.get(tree_url, params=params)
data = response.json()
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
GithubGetRepositoryTreeBlock.Output.TreeEntry(
path=item["path"],
type=item["type"],
size=item.get("size", 0),
sha=item["sha"],
)
for item in data["tree"]
]
return data.get("truncated", False), entries
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
truncated, entries = await self.get_tree(
credentials,
input_data.repo_url,
input_data.branch,
input_data.recursive,
)
yield "truncated", truncated
yield "entries", entries
for entry in entries:
yield "entry", entry
except Exception as e:
yield "error", str(e)

View File

@@ -1,125 +0,0 @@
import inspect
import pytest
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.data.execution import ExecutionContext
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
class TestPreparePrApiUrl:
def test_https_scheme_preserved(self):
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
assert result == "https://github.com/owner/repo/pulls/42/merge"
def test_http_scheme_preserved(self):
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
assert result == "http://github.com/owner/repo/pulls/1/files"
def test_no_scheme_defaults_to_https(self):
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
assert result == "https://github.com/owner/repo/pulls/5/merge"
def test_reviewers_path(self):
result = prepare_pr_api_url(
"https://github.com/owner/repo/pull/99", "requested_reviewers"
)
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
def test_invalid_url_returned_as_is(self):
url = "https://example.com/not-a-pr"
assert prepare_pr_api_url(url, "merge") == url
def test_empty_string(self):
assert prepare_pr_api_url("", "merge") == ""
# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).
def _mock_block(block, mocks: dict):
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
for name, mock_fn in mocks.items():
original = getattr(block, name)
if inspect.iscoroutinefunction(original):
async def async_mock(*args, _fn=mock_fn, **kwargs):
return _fn(*args, **kwargs)
setattr(block, name, async_mock)
else:
setattr(block, name, mock_fn)
def _raise(exc: Exception):
"""Helper that returns a callable which raises the given exception."""
def _raiser(*args, **kwargs):
raise exc
return _raiser
@pytest.mark.asyncio
async def test_merge_pr_error_path():
block = GithubMergePullRequestBlock()
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
input_data = {
"pr_url": "https://github.com/owner/repo/pull/1",
"merge_method": "squash",
"commit_title": "",
"commit_message": "",
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
pass
@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
block = GithubMultiFileCommitBlock()
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
input_data = {
"repo_url": "https://github.com/owner/repo",
"branch": "feature",
"commit_message": "test",
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(
input_data,
credentials=TEST_CREDENTIALS,
execution_context=ExecutionContext(),
):
pass
# ── FileOperation enum tests ──
class TestFileOperation:
def test_upsert_value(self):
assert FileOperation.UPSERT == "upsert"
def test_delete_value(self):
assert FileOperation.DELETE == "delete"
def test_invalid_value_raises(self):
with pytest.raises(ValueError):
FileOperation("create")
def test_invalid_value_raises_typo(self):
with pytest.raises(ValueError):
FileOperation("upser")

View File

@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
h.ignore_links = False
h.ignore_images = True
return h.handle(html_content)
except Exception:
# Keep extraction resilient if html2text is unavailable or fails.
except ImportError:
# Fallback: return raw HTML if html2text is not available
return html_content
# Handle content stored as attachment

View File

@@ -67,7 +67,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
@@ -144,11 +143,10 @@ class HITLReviewHelper:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
if is_graph_execution:
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
@@ -170,7 +168,6 @@ class HITLReviewHelper:
graph_version: int,
block_name: str = "Block",
editable: bool = False,
is_graph_execution: bool = True,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
@@ -200,7 +197,6 @@ class HITLReviewHelper:
graph_version=graph_version,
block_name=block_name,
editable=editable,
is_graph_execution=is_graph_execution,
)
if review_result is None:

View File

@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
from backend.blocks.search import GetRequest
from backend.data.model import SchemaField
from backend.util.exceptions import BlockExecutionError
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
) -> BlockOutput:
if input_data.raw_content:
try:
parsed_url, _, _ = await validate_url_host(input_data.url)
parsed_url, _, _ = await validate_url(input_data.url, [])
url = parsed_url.geturl()
except ValueError as e:
yield "error", f"Invalid URL: {e}"

View File

@@ -31,14 +31,10 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.request import validate_url_host
from backend.util.settings import Settings
from backend.util.text import TextFormatter
settings = Settings()
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
@@ -140,31 +136,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
CODESTRAL = "mistralai/codestral-2508"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -172,11 +156,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
MICROSOFT_PHI_4 = "microsoft/phi-4"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_3 = "x-ai/grok-3"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
@@ -354,41 +336,17 @@ MODEL_METADATA = {
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 2.5 Pro",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Pro Preview",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3 Flash Preview",
"OpenRouter",
"Google",
1,
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -396,15 +354,6 @@ MODEL_METADATA = {
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65536,
"Gemini 3.1 Flash Lite Preview",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
@@ -426,78 +375,12 @@ MODEL_METADATA = {
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
"open_router",
262144,
None,
"Mistral Large 3 2512",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
"open_router",
131072,
None,
"Mistral Medium 3.1",
"OpenRouter",
"Mistral AI",
2,
),
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
"open_router",
131072,
131072,
"Mistral Small 3.2 24B",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.CODESTRAL: ModelMetadata(
"open_router",
256000,
None,
"Codestral 2508",
"OpenRouter",
"Mistral AI",
1,
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
),
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Translate 08.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
"open_router",
256000,
32768,
"Command A Reasoning 08.2025",
"OpenRouter",
"Cohere",
3,
),
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
"open_router",
128000,
8192,
"Command A Vision 07.2025",
"OpenRouter",
"Cohere",
2,
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
@@ -510,15 +393,6 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
"open_router",
128000,
8000,
"Sonar Reasoning Pro",
"OpenRouter",
"Perplexity",
2,
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
@@ -564,9 +438,6 @@ MODEL_METADATA = {
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
@@ -576,15 +447,6 @@ MODEL_METADATA = {
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_3: ModelMetadata(
"open_router",
131072,
131072,
"Grok 3",
"OpenRouter",
"xAI",
2,
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
@@ -942,11 +804,6 @@ async def llm_call(
if tools:
raise ValueError("Ollama does not support tools.")
# Validate user-provided Ollama host to prevent SSRF etc.
await validate_url_host(
ollama_host, trusted_hostnames=[settings.config.ollama_host]
)
client = ollama.AsyncClient(host=ollama_host)
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -968,7 +825,7 @@ async def llm_call(
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr, field_validator
from pydantic import SecretStr
from backend.blocks._base import (
Block,
@@ -13,7 +13,6 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import BlockInput
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -22,7 +21,6 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.clients import OPENROUTER_BASE_URL
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -36,20 +34,6 @@ class PerplexityModel(str, Enum):
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
if isinstance(value, PerplexityModel):
return value
try:
return PerplexityModel(value)
except ValueError:
logger.warning(
f"Invalid PerplexityModel '{value}', "
f"falling back to {PerplexityModel.SONAR.value}"
)
return PerplexityModel.SONAR
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
@@ -88,25 +72,6 @@ class PerplexityBlock(Block):
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
@field_validator("model", mode="before")
@classmethod
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
"""Fall back to SONAR if the model value is not a valid
PerplexityModel (e.g. an OpenAI model ID set by the agent
generator)."""
return _sanitize_perplexity_model(v)
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
"""Sanitize the model field before JSON schema validation so that
invalid values are replaced with the default instead of raising a
BlockInputError."""
model_value = data.get("model")
if model_value is not None:
data["model"] = _sanitize_perplexity_model(model_value).value
return super().validate_data(data)
system_prompt: str = SchemaField(
title="System Prompt",
default="",
@@ -171,7 +136,7 @@ class PerplexityBlock(Block):
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url=OPENROUTER_BASE_URL,
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)

View File

@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
("post_id", "abc123"),
],
test_mock={"delete_post": lambda creds, post_id: True},
is_sensitive_action=True,
)
@staticmethod
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
("comment_id", "xyz789"),
],
test_mock={"delete_comment": lambda creds, comment_id: True},
is_sensitive_action=True,
)
@staticmethod

View File

@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
"_convert_to_color": lambda *args, **kwargs: "black",
},
is_sensitive_action=True,
)
async def run(

View File

@@ -1,81 +0,0 @@
"""Unit tests for PerplexityBlock model fallback behavior."""
import pytest
from backend.blocks.perplexity import (
TEST_CREDENTIALS_INPUT,
PerplexityBlock,
PerplexityModel,
)
def _make_input(**overrides) -> dict:
defaults = {
"prompt": "test query",
"credentials": TEST_CREDENTIALS_INPUT,
}
defaults.update(overrides)
return defaults
class TestPerplexityModelFallback:
"""Tests for fallback_invalid_model field_validator."""
def test_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
assert inp.model == PerplexityModel.SONAR
def test_another_invalid_model_falls_back_to_sonar(self):
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
assert inp.model == PerplexityModel.SONAR
def test_valid_model_string_is_kept(self):
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
assert inp.model == PerplexityModel.SONAR_PRO
def test_valid_enum_value_is_kept(self):
inp = PerplexityBlock.Input(
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
)
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
def test_default_model_when_omitted(self):
inp = PerplexityBlock.Input(**_make_input())
assert inp.model == PerplexityModel.SONAR
@pytest.mark.parametrize(
"model_value",
[
"perplexity/sonar",
"perplexity/sonar-pro",
"perplexity/sonar-deep-research",
],
)
def test_all_valid_models_accepted(self, model_value: str):
inp = PerplexityBlock.Input(**_make_input(model=model_value))
assert inp.model.value == model_value
class TestPerplexityValidateData:
"""Tests for validate_data which runs during block execution (before
Pydantic instantiation). Invalid models must be sanitized here so
JSON schema validation does not reject them."""
def test_invalid_model_sanitized_before_schema_validation(self):
data = _make_input(model="gpt-5.2-2025-12-11")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == PerplexityModel.SONAR.value
def test_valid_model_unchanged_by_validate_data(self):
data = _make_input(model="perplexity/sonar-pro")
error = PerplexityBlock.Input.validate_data(data)
assert error is None
assert data["model"] == "perplexity/sonar-pro"
def test_missing_model_uses_default(self):
data = _make_input() # no model key
error = PerplexityBlock.Input.validate_data(data)
assert error is None
inp = PerplexityBlock.Input(**data)
assert inp.model == PerplexityModel.SONAR

View File

@@ -40,7 +40,7 @@ from backend.copilot.response_model import (
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
_get_openai_client,
client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
@@ -89,7 +89,7 @@ async def _compress_session_messages(
result = await compress_context(
messages=messages_dict,
model=config.model,
client=_get_openai_client(),
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
@@ -235,7 +235,7 @@ async def stream_chat_completion_baseline(
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""

View File

@@ -1,13 +1,10 @@
"""Configuration management for chat system."""
import os
from typing import Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -22,7 +19,7 @@ class ChatConfig(BaseSettings):
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default=OPENROUTER_BASE_URL,
default="https://openrouter.ai/api/v1",
description="Base URL for API (e.g., for OpenRouter)",
)
@@ -115,37 +112,9 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
)
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
default="pause",
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
)
@property
def e2b_active(self) -> bool:
"""True when E2B is enabled and the API key is present.
Single source of truth for "should we use E2B right now?".
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
separately at call sites.
"""
return self.use_e2b_sandbox and bool(self.e2b_api_key)
@property
def active_e2b_api_key(self) -> str | None:
"""Return the E2B API key when E2B is enabled and configured, else None.
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
Use in callers::
if api_key := config.active_e2b_api_key:
# E2B is active; api_key is narrowed to str
"""
return self.e2b_api_key if self.e2b_active else None
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
@@ -195,7 +164,7 @@ class ChatConfig(BaseSettings):
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = OPENROUTER_BASE_URL
v = "https://openrouter.ai/api/v1"
return v
@field_validator("use_claude_agent_sdk", mode="before")

View File

@@ -1,38 +0,0 @@
"""Unit tests for ChatConfig."""
import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
_E2B_ENV_VARS = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
)
@pytest.fixture(autouse=True)
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
for var in _E2B_ENV_VARS:
monkeypatch.delenv(var, raising=False)
class TestE2BActive:
"""Tests for the e2b_active property — single source of truth for E2B usage."""
def test_both_enabled_and_key_present_returns_true(self):
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
assert cfg.e2b_active is True
def test_enabled_but_missing_key_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
assert cfg.e2b_active is False
def test_disabled_returns_false(self):
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
assert cfg.e2b_active is False

View File

@@ -6,32 +6,6 @@
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Prefix for all synthetic IDs generated by CoPilot block execution.
# Used to distinguish CoPilot-generated records from real graph execution records
# in PendingHumanReview and other tables.
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
# Separator used in synthetic node_exec_id to encode node_id.
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"
def is_copilot_synthetic_id(id_value: str) -> bool:
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
"""Extract node_id from a synthetic node_exec_id.
Format: "{node_id}:{random_hex}" → returns "{node_id}".
"""
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]

View File

@@ -1,128 +0,0 @@
"""Shared execution context for copilot SDK tool handlers.
All context variables and their accessors live here so that
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
without creating circular dependencies.
"""
import os
import re
from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
# validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
) -> None:
"""Set per-turn context variables used by file-resolution tool handlers."""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Return the current (user_id, session) pair for the active request."""
return _current_user_id.get(), _current_session.get()
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current session, or None if not active."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK working directory for the current session (empty string if unset)."""
return _current_sdk_cwd.get()
E2B_WORKDIR = "/home/user"
def resolve_sandbox_path(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped :class:`WorkspaceManager`.
Placed here (rather than in ``tools/workspace_files``) so that modules
like ``sdk/file_ref`` can import it without triggering the heavy
``tools/__init__`` import chain.
"""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -1,163 +0,0 @@
"""Tests for context.py — execution context variables and path helpers."""
from __future__ import annotations
import os
import tempfile
from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
set_execution_context,
)
def _make_session() -> MagicMock:
s = MagicMock()
s.session_id = "test-session"
return s
# ---------------------------------------------------------------------------
# Context variable getters
# ---------------------------------------------------------------------------
def test_get_execution_context_defaults():
"""get_execution_context returns (None, session) when user_id is not set."""
set_execution_context(None, _make_session())
user_id, session = get_execution_context()
assert user_id is None
assert session is not None
def test_set_and_get_execution_context():
"""set_execution_context stores user_id and session."""
mock_session = _make_session()
set_execution_context("user-abc", mock_session)
user_id, session = get_execution_context()
assert user_id == "user-abc"
assert session is mock_session
def test_get_current_sandbox_none_by_default():
"""get_current_sandbox returns None when no sandbox is set."""
set_execution_context("u1", _make_session(), sandbox=None)
assert get_current_sandbox() is None
def test_get_current_sandbox_returns_set_value():
"""get_current_sandbox returns the sandbox set via set_execution_context."""
mock_sandbox = MagicMock()
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
assert get_current_sandbox() is mock_sandbox
def test_get_sdk_cwd_empty_when_not_set():
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
set_execution_context("u1", _make_session(), sdk_cwd=None)
assert get_sdk_cwd() == ""
def test_get_sdk_cwd_returns_set_value():
"""get_sdk_cwd returns the value set via set_execution_context."""
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
assert get_sdk_cwd() == "/tmp/copilot-test"
# ---------------------------------------------------------------------------
# is_allowed_local_path
# ---------------------------------------------------------------------------
def test_is_allowed_local_path_empty():
assert not is_allowed_local_path("")
def test_is_allowed_local_path_inside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
path = os.path.join(cwd, "file.txt")
assert is_allowed_local_path(path, cwd)
def test_is_allowed_local_path_sdk_cwd_itself():
with tempfile.TemporaryDirectory() as cwd:
assert is_allowed_local_path(cwd, cwd)
def test_is_allowed_local_path_outside_sdk_cwd():
with tempfile.TemporaryDirectory() as cwd:
assert not is_allowed_local_path("/etc/passwd", cwd)
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
"""Without sdk_cwd or project_dir, all paths are rejected."""
_current_project_dir.set("")
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
assert is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------
def test_resolve_sandbox_path_absolute_valid():
assert (
resolve_sandbox_path("/home/user/project/main.py")
== "/home/user/project/main.py"
)
def test_resolve_sandbox_path_relative():
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
def test_resolve_sandbox_path_workdir_itself():
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_resolve_sandbox_path_normalizes_dots():
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
def test_resolve_sandbox_path_escape_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_resolve_sandbox_path_absolute_outside_raises():
with pytest.raises(ValueError, match="/home/user"):
resolve_sandbox_path("/etc/passwd")

View File

@@ -1,138 +0,0 @@
"""Scheduler job to generate LLM-optimized block descriptions.
Runs periodically to rewrite block descriptions into concise, actionable
summaries that help the copilot LLM pick the right blocks during agent
generation.
"""
import asyncio
import logging
from backend.blocks import get_blocks
from backend.util.clients import get_database_manager_client, get_openai_client
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"You are a technical writer for an automation platform. "
"Rewrite the following block description to be concise (under 50 words), "
"informative, and actionable. Focus on what the block does and when to "
"use it. Output ONLY the rewritten description, nothing else. "
"Do not use markdown formatting."
)
# Rate-limit delay between sequential LLM calls (seconds)
_RATE_LIMIT_DELAY = 0.5
# Maximum tokens for optimized description generation
_MAX_DESCRIPTION_TOKENS = 150
# Model for generating optimized descriptions (fast, cheap)
_MODEL = "gpt-4o-mini"
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
"""Call the shared OpenAI client to rewrite each block description."""
client = get_openai_client()
if client is None:
logger.error(
"No OpenAI client configured, skipping block description optimization"
)
return {}
results: dict[str, str] = {}
for block in blocks:
block_id = block["id"]
block_name = block["name"]
description = block["description"]
try:
response = await client.chat.completions.create(
model=_MODEL,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": f"Block name: {block_name}\nDescription: {description}",
},
],
max_tokens=_MAX_DESCRIPTION_TOKENS,
)
optimized = (response.choices[0].message.content or "").strip()
if optimized:
results[block_id] = optimized
logger.debug("Optimized description for %s", block_name)
else:
logger.warning("Empty response for block %s", block_name)
except Exception:
logger.warning(
"Failed to optimize description for %s", block_name, exc_info=True
)
await asyncio.sleep(_RATE_LIMIT_DELAY)
return results
def optimize_block_descriptions() -> dict[str, int]:
"""Generate optimized descriptions for blocks that don't have one yet.
Uses the shared OpenAI client to rewrite block descriptions into concise
summaries suitable for agent generation prompts.
Returns:
Dict with counts: processed, success, failed, skipped.
"""
db_client = get_database_manager_client()
blocks = db_client.get_blocks_needing_optimization()
if not blocks:
logger.info("All blocks already have optimized descriptions")
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
non_empty = [b for b in blocks if b.get("description", "").strip()]
skipped = len(blocks) - len(non_empty)
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
stats = {
"processed": len(non_empty),
"success": len(new_descriptions),
"failed": len(non_empty) - len(new_descriptions),
"skipped": skipped,
}
logger.info(
"Block description optimization complete: "
"%d/%d succeeded, %d failed, %d skipped",
stats["success"],
stats["processed"],
stats["failed"],
stats["skipped"],
)
if new_descriptions:
for block_id, optimized in new_descriptions.items():
db_client.update_block_optimized_description(block_id, optimized)
# Update in-memory descriptions first so the cache rebuilds with fresh data.
try:
block_classes = get_blocks()
for block_id, optimized in new_descriptions.items():
if block_id in block_classes:
block_classes[block_id]._optimized_description = optimized
logger.info(
"Updated %d in-memory block descriptions", len(new_descriptions)
)
except Exception:
logger.warning(
"Could not update in-memory block descriptions", exc_info=True
)
from backend.copilot.tools.agent_generator.blocks import (
reset_block_caches, # local to avoid circular import
)
reset_block_caches()
return stats

View File

@@ -1,91 +0,0 @@
"""Unit tests for optimize_blocks._optimize_descriptions."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
def _make_client_response(text: str) -> MagicMock:
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
choice = MagicMock()
choice.message.content = text
response = MagicMock()
response.choices = [choice]
return response
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
class TestOptimizeDescriptions:
"""Tests for _optimize_descriptions async function."""
def test_returns_empty_when_no_client(self):
with patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
):
result = _run(
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
)
assert result == {}
def test_success_single_block(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("Short desc.")
)
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {"b1": "Short desc."}
client.chat.completions.create.assert_called_once()
def test_skips_block_on_exception(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch(
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
),
):
result = _run(_optimize_descriptions(blocks))
assert result == {}
def test_sleeps_between_blocks(self):
client = MagicMock()
client.chat.completions.create = AsyncMock(
return_value=_make_client_response("desc")
)
blocks = [
{"id": "b1", "name": "B1", "description": "d1"},
{"id": "b2", "name": "B2", "description": "d2"},
]
sleep_mock = AsyncMock()
with (
patch(
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
),
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
):
_run(_optimize_descriptions(blocks))
assert sleep_mock.call_count == 2
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)

View File

@@ -26,70 +26,6 @@ your message as a Markdown link or image:
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.
Syntax: `@@agptfile:<uri>[<start>-<end>]`
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
- `workspace://<file_id>` — workspace file by ID
- `workspace:///<path>` — workspace file by virtual path
- `/absolute/local/path` — ephemeral or sdk_cwd file
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
Examples:
```
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.py
```
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.
### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
These fields accept:
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
for large files (images, videos, PDFs). The platform passes the reference
directly to the block without reading the content into memory.
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
small files only.
When a block input has `format: "file"`, **pass the `workspace://` URI
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
payloads in tool arguments and preserves binary content (images, videos)
that would be corrupted by text encoding.
Example — committing an image file to GitHub:
```json
{
"files": [{
"path": "docs/hero.png",
"content": "workspace://abc123#image/png",
"operation": "upsert"
}]
}
```
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.

View File

@@ -3,45 +3,12 @@
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
circular import cycle::
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
function scope without a larger refactor (moving tool-name registration
to a separate lightweight module). The lazy-import pattern here is the
least invasive way to break the cycle while keeping module-level constants
intact.
"""
from typing import Any
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
# pair so new exports can be added without touching __getattr__ itself.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
}
def __getattr__(name: str) -> Any:
entry = _LAZY_IMPORTS.get(name)
if entry is not None:
module_path, attr = entry
import importlib
module = importlib.import_module(module_path, package=__name__)
value = getattr(module, attr)
globals()[name] = value
return value
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,155 +0,0 @@
## Agent Generation Guide
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.
### Workflow for Creating/Editing Agents
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
- Use block IDs from step 1 as `block_id` in nodes
- Wire outputs to inputs using links
- Set design-time config in `input_default`
- Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
the final `agent_json`
### Agent JSON Structure
```json
{
"id": "<UUID v4>", // auto-generated if omitted
"version": 1,
"is_active": true,
"name": "Agent Name",
"description": "What the agent does",
"nodes": [
{
"id": "<UUID v4>",
"block_id": "<block UUID from find_block>",
"input_default": {
"field_name": "design-time value"
},
"metadata": {
"position": {"x": 0, "y": 0},
"customized_name": "Optional display name"
}
}
],
"links": [
{
"id": "<UUID v4>",
"source_id": "<source node UUID>",
"source_name": "output_field_name",
"sink_id": "<sink node UUID>",
"sink_name": "input_field_name",
"is_static": false
}
]
}
```
### REQUIRED: AgentInputBlock and AgentOutputBlock
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
These define the agent's interface — what it accepts and what it produces.
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
- Defines a user-facing input field on the agent
- Required `input_default` fields: `name` (str), `value` (default: null)
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
- Output: `result` — the user-provided value at runtime
- Create one AgentInputBlock per distinct input the agent needs
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
- Defines a user-facing output displayed after the agent runs
- Required `input_default` fields: `name` (str)
- The `value` input should be linked from another block's output
- Optional: `title`, `description`, `format` (Jinja2 template)
- Create one AgentOutputBlock per distinct result to show the user
Without these blocks, the agent has no interface and the user cannot provide
inputs or see outputs. NEVER skip them.
### Key Rules
- **Name & description**: Include `name` and `description` in the agent JSON
when creating a new agent, or when editing and the agent's purpose changed.
Without these the agent gets a generic default name.
- **Design-time vs runtime**: `input_default` = values known at build time.
For user-provided values, create an `AgentInputBlock` node and link its
output to the consuming block's input.
- **Credentials**: Do NOT require credentials upfront. Users configure
credentials later in the platform UI after the agent is saved.
- **Node spacing**: Position nodes with at least 800 X-units between them.
- **Nested properties**: Use `parentField_#_childField` notation in link
sink_name/source_name to access nested object fields.
- **is_static links**: Set `is_static: true` when the link carries a
design-time constant (matches a field in inputSchema with a default).
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
literal braces in prompt strings — single `{` and `}` are for
template variables.
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
`graph_version` in input_default, and wire inputs/outputs to match
the sub-agent's schema.
### Using Sub-Agents (AgentExecutorBlock)
To compose agents using other agents as sub-agents:
1. Call `find_library_agent` to find the sub-agent — the response includes
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
3. Set `input_default`:
- `graph_id`: from the library agent's `graph_id`
- `graph_version`: from the library agent's `graph_version`
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
- `user_id`: leave as `""` (filled at runtime)
- `inputs`: `{}` (populated by links at runtime)
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
property names (e.g., if input_schema has a `"url"` property, use
`"url"` as the sink_name)
5. Wire outputs: link from source names matching the sub-agent's
`output_schema` property names
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
the library agent IDs used, so the fixer can validate schemas
### Using MCP Tools (MCPToolBlock)
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
3. Set `input_default`:
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
- `selected_tool`: the tool name on that server
- `tool_input_schema`: JSON Schema for the tool's inputs
- `tool_arguments`: `{}` (populated by links or hardcoded values)
4. The block requires MCP credentials — the user configures these in the
platform UI after the agent is saved
5. Wire inputs using the tool argument field name directly as the sink_name
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
automatically collects top-level fields matching tool_input_schema into
tool_arguments.
6. Output: `result` (the tool's return value) and `error` (error message)
### Example: Simple AI Text Processor
A minimal agent with input, processing, and output:
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
input_default: {"name": "user_text", "title": "Text to process"},
output: "result")
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
input_default: {"name": "summary", "title": "Summary"},
input: "value" linked from Node 2's output)

View File

@@ -11,7 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from collections.abc import Callable
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
@@ -27,19 +27,6 @@ from ..response_model import (
logger = logging.getLogger(__name__)
@dataclass
class CompactionResult:
"""Result of emit_end_if_ready — bundles events with compaction metadata.
Eliminates the need for separate ``compaction_just_ended`` checks,
preventing TOCTOU races between the emit call and the flag read.
"""
events: list[StreamBaseResponse] = field(default_factory=list)
just_ended: bool = False
transcript_path: str = ""
# ---------------------------------------------------------------------------
# Event builders (private — use CompactionTracker or compaction_events)
# ---------------------------------------------------------------------------
@@ -190,22 +177,11 @@ class CompactionTracker:
self._start_emitted = False
self._done = False
self._tool_call_id = ""
self._transcript_path: str = ""
def on_compact(self, transcript_path: str = "") -> None:
"""Callback for the PreCompact hook. Stores transcript_path."""
if (
self._transcript_path
and transcript_path
and self._transcript_path != transcript_path
):
logger.warning(
"[Compaction] Overwriting transcript_path %s -> %s",
self._transcript_path,
transcript_path,
)
self._transcript_path = transcript_path
self._compact_start.set()
@property
def on_compact(self) -> Callable[[], None]:
"""Callback for the PreCompact hook."""
return self._compact_start.set
# ------------------------------------------------------------------
# Pre-query compaction
@@ -225,7 +201,6 @@ class CompactionTracker:
self._done = False
self._start_emitted = False
self._tool_call_id = ""
self._transcript_path = ""
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
"""If the PreCompact hook fired, emit start events (spinning tool)."""
@@ -236,20 +211,15 @@ class CompactionTracker:
return _start_events(self._tool_call_id)
return []
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
"""If compaction is in progress, emit end events and persist.
Returns a ``CompactionResult`` with ``just_ended=True`` and the
captured ``transcript_path`` when a compaction cycle completes.
This avoids a separate flag check (TOCTOU-safe).
"""
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
"""If compaction is in progress, emit end events and persist."""
# Yield so pending hook tasks can set compact_start
await asyncio.sleep(0)
if self._done:
return CompactionResult()
return []
if not self._start_emitted and not self._compact_start.is_set():
return CompactionResult()
return []
if self._start_emitted:
# Close the open spinner
@@ -262,12 +232,8 @@ class CompactionTracker:
COMPACTION_DONE_MSG, tool_call_id=persist_id
)
transcript_path = self._transcript_path
self._compact_start.clear()
self._start_emitted = False
self._done = True
self._transcript_path = ""
_persist(session, persist_id, COMPACTION_DONE_MSG)
return CompactionResult(
events=done_events, just_ended=True, transcript_path=transcript_path
)
return done_events

View File

@@ -195,11 +195,10 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert len(result.events) == 2
assert isinstance(result.events[0], StreamToolOutputAvailable)
assert isinstance(result.events[1], StreamFinishStep)
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 2
assert isinstance(evts[0], StreamToolOutputAvailable)
assert isinstance(evts[1], StreamFinishStep)
# Should persist
assert len(session.messages) == 2
@@ -211,32 +210,28 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
# Don't call emit_start_if_ready
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert len(result.events) == 5 # Full self-contained event
assert isinstance(result.events[0], StreamStartStep)
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 5 # Full self-contained event
assert isinstance(evts[0], StreamStartStep)
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_no_op_when_no_new_compaction(self):
async def test_emit_end_no_op_when_done(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
# Second call should be no-op (no new on_compact)
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is False
assert result2.events == []
await tracker.emit_end_if_ready(session)
# Second call should be no-op
evts = await tracker.emit_end_if_ready(session)
assert evts == []
@pytest.mark.asyncio
async def test_emit_end_no_op_when_nothing_happened(self):
tracker = CompactionTracker()
session = _make_session()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is False
assert result.events == []
evts = await tracker.emit_end_if_ready(session)
assert evts == []
def test_emit_pre_query(self):
tracker = CompactionTracker()
@@ -251,29 +246,20 @@ class TestCompactionTracker:
tracker._done = True
tracker._start_emitted = True
tracker._tool_call_id = "old"
tracker._transcript_path = "/some/path"
tracker.reset_for_query()
assert tracker._done is False
assert tracker._start_emitted is False
assert tracker._tool_call_id == ""
assert tracker._transcript_path == ""
@pytest.mark.asyncio
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
"""After pre-query compaction, SDK compaction is blocked until
reset_for_query is called."""
async def test_pre_query_blocks_sdk_compaction(self):
"""After pre-query compaction, SDK compaction events are suppressed."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.on_compact()
# _done is True so emit_start_if_ready is blocked
evts = tracker.emit_start_if_ready()
assert evts == []
# Reset clears _done, allowing subsequent compaction
tracker.reset_for_query()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3
assert evts == [] # _done blocks it
@pytest.mark.asyncio
async def test_reset_allows_new_compaction(self):
@@ -293,9 +279,9 @@ class TestCompactionTracker:
session = _make_session()
tracker.on_compact()
start_evts = tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
end_evts = await tracker.emit_end_if_ready(session)
start_evt = start_evts[1]
end_evt = result.events[0]
end_evt = end_evts[0]
assert isinstance(start_evt, StreamToolInputStart)
assert isinstance(end_evt, StreamToolOutputAvailable)
assert start_evt.toolCallId == end_evt.toolCallId
@@ -303,105 +289,3 @@ class TestCompactionTracker:
tool_calls = session.messages[0].tool_calls
assert tool_calls is not None
assert tool_calls[0]["id"] == start_evt.toolCallId
@pytest.mark.asyncio
async def test_multiple_compactions_within_query(self):
"""Two mid-stream compactions within a single query both trigger."""
tracker = CompactionTracker()
session = _make_session()
# First compaction cycle
tracker.on_compact("/path/1")
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
assert len(result1.events) == 2
assert result1.transcript_path == "/path/1"
# Second compaction cycle (should NOT be blocked — _done resets
# because emit_end_if_ready sets it True, but the next on_compact
# + emit_start_if_ready checks !_done which IS True now.
# So we need reset_for_query between queries, but within a single
# query multiple compactions work because _done blocks emit_start
# until the next message arrives, at which point emit_end detects it)
#
# Actually: _done=True blocks emit_start_if_ready, so we need
# the stream loop to reset. In practice service.py doesn't call
# reset between compactions within the same query — let's verify
# the actual behavior.
tracker.on_compact("/path/2")
# _done is True from first compaction, so start is blocked
start_evts = tracker.emit_start_if_ready()
assert start_evts == []
# But emit_end returns no-op because _done is True
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is False
@pytest.mark.asyncio
async def test_multiple_compactions_with_intervening_message(self):
"""Multiple compactions work when the stream loop processes messages between them.
In the real service.py flow:
1. PreCompact fires → on_compact()
2. emit_start shows spinner
3. Next message arrives → emit_end completes compaction (_done=True)
4. Stream continues processing messages...
5. If a second PreCompact fires, _done=True blocks emit_start
6. But the next message triggers emit_end, which sees _done=True → no-op
7. The stream loop needs to detect this and handle accordingly
The actual flow for multiple compactions within a query requires
_done to be cleared between them. The service.py code uses
CompactionResult.just_ended to trigger replace_entries, and _done
stays True until reset_for_query.
"""
tracker = CompactionTracker()
session = _make_session()
# First compaction
tracker.on_compact("/path/1")
tracker.emit_start_if_ready()
result1 = await tracker.emit_end_if_ready(session)
assert result1.just_ended is True
assert result1.transcript_path == "/path/1"
# Simulate reset between queries
tracker.reset_for_query()
# Second compaction in new query
tracker.on_compact("/path/2")
start_evts = tracker.emit_start_if_ready()
assert len(start_evts) == 3
result2 = await tracker.emit_end_if_ready(session)
assert result2.just_ended is True
assert result2.transcript_path == "/path/2"
def test_on_compact_stores_transcript_path(self):
tracker = CompactionTracker()
tracker.on_compact("/some/path.jsonl")
assert tracker._transcript_path == "/some/path.jsonl"
@pytest.mark.asyncio
async def test_emit_end_returns_transcript_path(self):
"""CompactionResult includes the transcript_path from on_compact."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact("/my/session.jsonl")
tracker.emit_start_if_ready()
result = await tracker.emit_end_if_ready(session)
assert result.just_ended is True
assert result.transcript_path == "/my/session.jsonl"
# transcript_path is cleared after emit_end
assert tracker._transcript_path == ""
@pytest.mark.asyncio
async def test_emit_end_clears_transcript_path(self):
"""After emit_end, _transcript_path is reset so it doesn't leak to
subsequent non-compaction emit_end calls."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact("/first/path.jsonl")
tracker.emit_start_if_ready()
await tracker.emit_end_if_ready(session)
# After compaction, _transcript_path is cleared
assert tracker._transcript_path == ""

View File

@@ -8,6 +8,8 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
@@ -15,23 +17,36 @@ import os
import shlex
from typing import Any, Callable
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
logger = logging.getLogger(__name__)
def _get_sandbox():
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
return is_allowed_local_path(path, get_sdk_cwd())
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
@@ -48,7 +63,7 @@ def _get_sandbox_and_path(
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = resolve_sandbox_path(file_path)
remote = _resolve_remote(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
@@ -58,7 +73,6 @@ def _get_sandbox_and_path(
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
@@ -90,7 +104,6 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
"""Write content to a sandbox file, creating parent directories as needed."""
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
@@ -114,7 +127,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
"""Replace a substring in a sandbox file, with optional replace-all support."""
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
@@ -160,7 +172,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
"""Find files matching a name pattern inside the sandbox using ``find``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
@@ -172,7 +183,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -187,7 +198,6 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
@@ -200,7 +210,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
@@ -228,7 +238,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded, encoding="utf-8", errors="replace") as fh:
with open(expanded) as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)

View File

@@ -7,60 +7,59 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# _resolve_remote — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveSandboxPath:
class TestResolveRemote:
def test_relative_path_resolved(self):
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert resolve_sandbox_path("/home/user") == "/home/user"
assert _resolve_remote("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
assert _resolve_remote("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("../../etc/passwd")
_resolve_remote("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
_resolve_remote("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/etc/passwd")
_resolve_remote("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/")
_resolve_remote("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/other/file.txt")
_resolve_remote("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert resolve_sandbox_path("src/") == "/home/user/src"
assert _resolve_remote("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------

View File

@@ -1,531 +0,0 @@
"""End-to-end compaction flow test.
Simulates the full service.py compaction lifecycle using real-format
JSONL session files — no SDK subprocess needed. Exercises:
1. TranscriptBuilder loads a "downloaded" transcript
2. User query appended, assistant response streamed
3. PreCompact hook fires → CompactionTracker.on_compact()
4. Next message → emit_start_if_ready() yields spinner events
5. Message after that → emit_end_if_ready() returns CompactionResult
6. read_compacted_entries() reads the CLI session file
7. TranscriptBuilder.replace_entries() syncs state
8. More messages appended post-compaction
9. to_jsonl() exports full state for upload
10. Fresh builder loads the export — roundtrip verified
"""
import asyncio
from backend.copilot.model import ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import CompactionTracker
from backend.copilot.sdk.transcript import (
read_compacted_entries,
strip_progress_entries,
)
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
# ---------------------------------------------------------------------------
# Fixtures: realistic CLI session file content
# ---------------------------------------------------------------------------
# Pre-compaction conversation
USER_1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "What files are in this project?"},
}
ASST_1_THINKING = {
"type": "assistant",
"uuid": "a1-think",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
"stop_reason": None,
"stop_sequence": None,
},
}
ASST_1_TOOL = {
"type": "assistant",
"uuid": "a1-tool",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_sdk_aaa",
"type": "message",
"content": [
{
"type": "tool_use",
"id": "tu1",
"name": "Bash",
"input": {"command": "ls"},
}
],
"stop_reason": "tool_use",
"stop_sequence": None,
},
}
TOOL_RESULT_1 = {
"type": "user",
"uuid": "tr1",
"parentUuid": "a1-tool",
"message": {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "tu1",
"content": "file1.py\nfile2.py",
}
],
},
}
ASST_1_TEXT = {
"type": "assistant",
"uuid": "a1-text",
"parentUuid": "tr1",
"message": {
"role": "assistant",
"id": "msg_sdk_bbb",
"type": "message",
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Progress entries (should be stripped during upload)
PROGRESS_1 = {
"type": "progress",
"uuid": "prog1",
"parentUuid": "a1-tool",
"data": {"type": "bash_progress", "stdout": "running ls..."},
}
# Second user message
USER_2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1-text",
"message": {"role": "user", "content": "Show me file1.py"},
}
ASST_2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_sdk_ccc",
"type": "message",
"content": [{"type": "text", "text": "Here is file1.py content..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# --- Compaction summary (written by CLI after context compaction) ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {
"role": "user",
"content": (
"Summary: User asked about project files. Found file1.py and file2.py. "
"User then asked to see file1.py."
),
},
}
# Post-compaction assistant response
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a3",
"parentUuid": "cs1",
"message": {
"role": "assistant",
"id": "msg_sdk_ddd",
"type": "message",
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# Post-compaction user follow-up
USER_3 = {
"type": "user",
"uuid": "u3",
"parentUuid": "a3",
"message": {"role": "user", "content": "Now show file2.py"},
}
ASST_3 = {
"type": "assistant",
"uuid": "a4",
"parentUuid": "u3",
"message": {
"role": "assistant",
"id": "msg_sdk_eee",
"type": "message",
"content": [{"type": "text", "text": "Here is file2.py..."}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
# ---------------------------------------------------------------------------
# E2E test
# ---------------------------------------------------------------------------
class TestCompactionE2E:
def _write_session_file(self, session_dir, entries):
"""Write a CLI session JSONL file."""
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(*entries))
return path
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
"""Simulate the complete service.py compaction flow.
Timeline:
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
2. Current turn: download → load_previous
3. User sends "Now show file2.py" → append_user
4. SDK starts streaming response
5. Mid-stream: PreCompact hook fires (context too large)
6. CLI writes compaction summary to session file
7. Next SDK message → emit_start (spinner)
8. Following message → emit_end (CompactionResult)
9. read_compacted_entries reads the session file
10. replace_entries syncs TranscriptBuilder
11. More assistant messages appended
12. Export → upload → next turn downloads it
"""
# --- Setup CLI projects directory ---
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
previous_transcript = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 7
# --- Step 3: User sends new query ---
builder.append_user("Now show file2.py")
assert builder.entry_count == 8
# --- Step 4: SDK starts streaming ---
builder.append_assistant(
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
model="claude-sonnet-4-20250514",
)
assert builder.entry_count == 9
# --- Step 5-6: PreCompact fires, CLI writes session file ---
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
USER_3,
ASST_3,
],
)
# --- Step 7: CompactionTracker receives PreCompact hook ---
tracker = CompactionTracker()
session = ChatSession.new(user_id="test-user")
tracker.on_compact(str(session_file))
# --- Step 8: Next SDK message arrives → emit_start ---
start_events = tracker.emit_start_if_ready()
assert len(start_events) == 3
assert isinstance(start_events[0], StreamStartStep)
assert isinstance(start_events[1], StreamToolInputStart)
assert isinstance(start_events[2], StreamToolInputAvailable)
# Verify tool_call_id is set
tool_call_id = start_events[1].toolCallId
assert tool_call_id.startswith("compaction-")
# --- Step 9: Following message → emit_end ---
result = _run(tracker.emit_end_if_ready(session))
assert result.just_ended is True
assert result.transcript_path == str(session_file)
assert len(result.events) == 2
assert isinstance(result.events[0], StreamToolOutputAvailable)
assert isinstance(result.events[1], StreamFinishStep)
# Verify same tool_call_id
assert result.events[0].toolCallId == tool_call_id
# Session should have compaction messages persisted
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[1].role == "tool"
# --- Step 10: read_compacted_entries + replace_entries ---
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
assert len(compacted) == 4
assert compacted[0]["uuid"] == "cs1"
assert compacted[0]["isCompactSummary"] is True
# Replace builder state with compacted entries
old_count = builder.entry_count
builder.replace_entries(compacted)
assert builder.entry_count == 4 # Only compacted entries
assert builder.entry_count < old_count # Compaction reduced entries
# --- Step 11: More assistant messages after compaction ---
builder.append_assistant(
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
model="claude-sonnet-4-20250514",
stop_reason="end_turn",
)
assert builder.entry_count == 5
# --- Step 12: Export for upload ---
output = builder.to_jsonl()
assert output # Not empty
output_entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(output_entries) == 5
# Verify structure:
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
assert output_entries[0]["type"] == "summary"
assert output_entries[0].get("isCompactSummary") is True
assert output_entries[0]["uuid"] == "cs1"
assert output_entries[1]["uuid"] == "a3"
assert output_entries[2]["uuid"] == "u3"
assert output_entries[3]["uuid"] == "a4"
assert output_entries[4]["type"] == "assistant"
# Verify parent chain is intact
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
assert output_entries[4]["parentUuid"] == "a4" # new → a4
# --- Step 13: Roundtrip — next turn loads this export ---
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 5
# isCompactSummary survives roundtrip
output2 = builder2.to_jsonl()
first_entry = json.loads(output2.strip().split("\n")[0])
assert first_entry.get("isCompactSummary") is True
# Can append more messages
builder2.append_user("What about file3.py?")
assert builder2.entry_count == 6
final_output = builder2.to_jsonl()
last_entry = json.loads(final_output.strip().split("\n")[-1])
assert last_entry["type"] == "user"
# Parented to the last entry from previous turn
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
"""Two compactions in the same session (across reset_for_query)."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
tracker = CompactionTracker()
session = ChatSession.new(user_id="test")
builder = TranscriptBuilder()
# --- First query with compaction ---
builder.append_user("first question")
builder.append_assistant([{"type": "text", "text": "first answer"}])
# Write session file for first compaction
first_summary = {
"type": "summary",
"uuid": "cs-first",
"isCompactSummary": True,
"message": {"role": "user", "content": "First compaction summary"},
}
first_post = {
"type": "assistant",
"uuid": "a-first",
"parentUuid": "cs-first",
"message": {"role": "assistant", "content": "first post-compact"},
}
file1 = session_dir / "session1.jsonl"
file1.write_text(_make_jsonl(first_summary, first_post))
tracker.on_compact(str(file1))
tracker.emit_start_if_ready()
result1 = _run(tracker.emit_end_if_ready(session))
assert result1.just_ended is True
compacted1 = read_compacted_entries(str(file1))
assert compacted1 is not None
builder.replace_entries(compacted1)
assert builder.entry_count == 2
# --- Reset for second query ---
tracker.reset_for_query()
# --- Second query with compaction ---
builder.append_user("second question")
builder.append_assistant([{"type": "text", "text": "second answer"}])
second_summary = {
"type": "summary",
"uuid": "cs-second",
"isCompactSummary": True,
"message": {"role": "user", "content": "Second compaction summary"},
}
second_post = {
"type": "assistant",
"uuid": "a-second",
"parentUuid": "cs-second",
"message": {"role": "assistant", "content": "second post-compact"},
}
file2 = session_dir / "session2.jsonl"
file2.write_text(_make_jsonl(second_summary, second_post))
tracker.on_compact(str(file2))
tracker.emit_start_if_ready()
result2 = _run(tracker.emit_end_if_ready(session))
assert result2.just_ended is True
compacted2 = read_compacted_entries(str(file2))
assert compacted2 is not None
builder.replace_entries(compacted2)
assert builder.entry_count == 2 # Only second compaction entries
# Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs-second"
assert entries[0].get("isCompactSummary") is True
def test_strip_progress_then_load_then_compact_roundtrip(
self, tmp_path, monkeypatch
):
"""Full pipeline: strip → load → compact → replace → export → reload.
This tests the exact sequence that happens across two turns:
Turn 1: SDK produces transcript with progress entries
Upload: strip_progress_entries removes progress, upload to cloud
Turn 2: Download → load_previous → compaction fires → replace → export
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
"""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# --- Turn 1: SDK produces raw transcript ---
raw_content = _make_jsonl(
USER_1,
ASST_1_THINKING,
ASST_1_TOOL,
PROGRESS_1,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
)
# Strip progress for upload
stripped = strip_progress_entries(raw_content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
# Progress should be gone
assert not any(e.get("type") == "progress" for e in stripped_entries)
assert len(stripped_entries) == 7 # 8 - 1 progress
# --- Turn 2: Download stripped, load, compaction happens ---
builder = TranscriptBuilder()
builder.load_previous(stripped)
assert builder.entry_count == 7
builder.append_user("Now show file2.py")
builder.append_assistant(
[{"type": "text", "text": "Reading file2.py..."}],
model="claude-sonnet-4-20250514",
)
# CLI writes session file with compaction
session_file = self._write_session_file(
session_dir,
[
USER_1,
ASST_1_TOOL,
TOOL_RESULT_1,
ASST_1_TEXT,
USER_2,
ASST_2,
COMPACT_SUMMARY,
POST_COMPACT_ASST,
],
)
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
builder.replace_entries(compacted)
# Append post-compaction message
builder.append_user("Thanks!")
output = builder.to_jsonl()
# --- Turn 3: Fresh load of Turn 2 export ---
builder3 = TranscriptBuilder()
builder3.load_previous(output)
# Should have: compact_summary + post_compact_asst + "Thanks!"
assert builder3.entry_count == 3
# Compact summary survived the full pipeline
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
assert first.get("isCompactSummary") is True
assert first["type"] == "summary"

View File

@@ -1,715 +0,0 @@
"""File reference protocol for tool call inputs.
Allows the LLM to pass a file reference instead of embedding large content
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
arguments before the tool is executed.
Protocol
--------
@@agptfile:<uri>[<start>-<end>]
``<uri>`` (required)
- ``workspace://<file_id>`` — workspace file by ID
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
- ``workspace:///<path>`` — workspace file by virtual path
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
- Any absolute path that resolves inside the E2B sandbox
(``/home/user/...``) when a sandbox is active
``[<start>-<end>]`` (optional)
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
Omit to read the entire file.
Examples
--------
@@agptfile:workspace://abc123
@@agptfile:workspace://abc123[10-50]
@@agptfile:workspace:///reports/q1.md
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
@@agptfile:/home/user/script.sh
"""
import itertools
import logging
import os
import re
from dataclasses import dataclass
from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.util.file import parse_workspace_uri
from backend.util.file_content_parser import (
BINARY_FORMATS,
MIME_TO_FORMAT,
PARSE_EXCEPTIONS,
infer_format_from_uri,
parse_file_content,
)
from backend.util.type import MediaFileType
class FileRefExpansionError(Exception):
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
Separating this from inline substitution lets callers (e.g. the MCP tool
wrapper) block tool execution and surface a helpful error to the model
rather than passing an ``[file-ref error: …]`` string as actual input.
"""
logger = logging.getLogger(__name__)
FILE_REF_PREFIX = "@@agptfile:"
# Matches: @@agptfile:<uri>[start-end]?
# Group 1 URI; must start with '/' (absolute path) or 'workspace://'
# Group 2 start line (optional)
# Group 3 end line (optional)
_FILE_REF_RE = re.compile(
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
)
# Maximum characters returned for a single file reference expansion.
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
# Maximum raw byte size for bare ref structured parsing (10 MB).
_MAX_BARE_REF_BYTES = 10_000_000
@dataclass
class FileRef:
uri: str
start_line: int | None # 1-indexed, inclusive
end_line: int | None # 1-indexed, inclusive
# ---------------------------------------------------------------------------
# Public API (top-down: main functions first, helpers below)
# ---------------------------------------------------------------------------
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
expand references embedded in larger strings.
"""
m = _FILE_REF_RE.fullmatch(text.strip())
if not m:
return None
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if start is not None and start < 1:
return None
if end is not None and end < 1:
return None
if start is not None and end is not None and end < start:
return None
return FileRef(uri=m.group(1), start_line=start, end_line=end)
async def read_file_bytes(
uri: str,
user_id: str | None,
session: ChatSession,
) -> bytes:
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
Raises :class:`ValueError` if the URI cannot be resolved.
"""
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_workspace_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
data = await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except (PermissionError, OSError) as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
except (AttributeError, TypeError, RuntimeError) as exc:
# AttributeError/TypeError: workspace manager returned an
# unexpected type or interface; RuntimeError: async runtime issues.
logger.warning("Unexpected error reading %s: %s", plain, exc)
raise ValueError(f"Failed to read {plain}: {exc}") from exc
# NOTE: Workspace API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
# Read with a one-byte overshoot to detect files that exceed the limit
# without a separate os.path.getsize call (avoids TOCTOU race).
with open(resolved, "rb") as fh:
data = fh.read(_MAX_BARE_REF_BYTES + 1)
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
f"limit {_MAX_BARE_REF_BYTES})"
)
return data
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except OSError as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
if sandbox is not None:
try:
remote = resolve_sandbox_path(plain)
except ValueError as exc:
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
data = bytes(await sandbox.files.read(remote, format="bytes"))
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
except Exception as exc:
# E2B SDK raises SandboxException subclasses (NotFoundException,
# TimeoutException, NotEnoughSpaceException, etc.) which don't
# inherit from standard exceptions. Import lazily to avoid a
# hard dependency on e2b at module level.
try:
from e2b.exceptions import SandboxException # noqa: PLC0415
if isinstance(exc, SandboxException):
raise ValueError(
f"Failed to read from sandbox: {plain}: {exc}"
) from exc
except ImportError:
pass
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
# so they surface as real bugs rather than being silently masked.
raise
# NOTE: E2B sandbox API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
)
async def resolve_file_ref(
ref: FileRef,
user_id: str | None,
session: ChatSession,
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: ChatSession,
*,
raise_on_error: bool = False,
) -> str:
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
Non-reference text is passed through unchanged.
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
where partial expansion is acceptable.
If *raise_on_error* is ``True``, any resolution failure raises
:class:`FileRefExpansionError` immediately so the caller can block the
operation and surface a clean error to the model.
"""
if FILE_REF_PREFIX not in text:
return text
result: list[str] = []
last_end = 0
total_chars = 0
for m in _FILE_REF_RE.finditer(text):
result.append(text[last_end : m.start()])
start = int(m.group(2)) if m.group(2) else None
end = int(m.group(3)) if m.group(3) else None
if (start is not None and start < 1) or (end is not None and end < 1):
msg = f"line numbers must be >= 1: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
if start is not None and end is not None and end < start:
msg = f"end line must be >= start line: {m.group(0)}"
if raise_on_error:
raise FileRefExpansionError(msg)
result.append(f"[file-ref error: {msg}]")
last_end = m.end()
continue
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
try:
content = await resolve_file_ref(ref, user_id, session)
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
# remaining == 0 means the budget was exactly exhausted by the
# previous ref. The elif below (len > remaining) won't catch
# this since 0 > 0 is false, so we need the <= 0 check.
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
content = content[:remaining] + "\n... [total budget exhausted]"
total_chars += len(content)
result.append(content)
except ValueError as exc:
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
if raise_on_error:
raise FileRefExpansionError(str(exc)) from exc
result.append(f"[file-ref error: {exc}]")
last_end = m.end()
result.append(text[last_end:])
return "".join(result)
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: ChatSession,
*,
input_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
**Bare references** (the entire argument value is a single
``@@agptfile:...`` token with no surrounding text) are resolved and then
parsed according to the file's extension or MIME type. See
:mod:`backend.util.file_content_parser` for the full list of supported
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
When *input_schema* is provided and the target property has
``"type": "string"``, structured parsing is skipped — the raw file content
is returned as a plain string so blocks receive the original text.
If the format is unrecognised or parsing fails, the content is returned as
a plain string (the fallback).
**Embedded references** (``@@agptfile:`` mixed with other text) always
produce a plain string — structured parsing only applies to bare refs.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
response that lets the model correct the reference before retrying.
"""
if not args:
return args
properties = (input_schema or {}).get("properties", {})
async def _expand(
value: Any,
*,
prop_schema: dict[str, Any] | None = None,
) -> Any:
"""Recursively expand a single argument value.
Strings are checked for ``@@agptfile:`` references and expanded
(bare refs get structured parsing; embedded refs get inline
substitution). Dicts and lists are traversed recursively,
threading the corresponding sub-schema from *prop_schema* so
that nested fields also receive correct type-aware expansion.
Non-string scalars pass through unchanged.
"""
if isinstance(value, str):
ref = parse_file_ref(value)
if ref is not None:
# MediaFileType fields: return the raw URI immediately —
# no file reading, no format inference, no content parsing.
if _is_media_file_field(prop_schema):
return ref.uri
fmt = infer_format_from_uri(ref.uri)
# Workspace URIs by ID (workspace://abc123) have no extension.
# When the MIME fragment is also missing, fall back to the
# workspace file manager's metadata for format detection.
if fmt is None and ref.uri.startswith("workspace://"):
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
# Not a bare ref — do normal inline expansion.
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
# When the schema says this is an object but doesn't define
# inner properties, skip expansion — the caller (e.g.
# RunBlockTool) will expand with the actual nested schema.
if (
prop_schema is not None
and prop_schema.get("type") == "object"
and "properties" not in prop_schema
):
return value
nested_props = (prop_schema or {}).get("properties", {})
return {
k: await _expand(v, prop_schema=nested_props.get(k))
for k, v in value.items()
}
if isinstance(value, list):
items_schema = (prop_schema or {}).get("items")
return [await _expand(item, prop_schema=items_schema) for item in value]
return value
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
# ---------------------------------------------------------------------------
# Private helpers (used by the public functions above)
# ---------------------------------------------------------------------------
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive).
When the requested range extends beyond the file, a note is appended
so the LLM knows it received the entire remaining content.
"""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
total = len(lines)
s = (start - 1) if start is not None else 0
e = end if end is not None else total
selected = list(itertools.islice(lines, s, e))
result = "".join(selected)
if end is not None and end > total:
result += f"\n[Note: file has only {total} lines]\n"
return result
def _to_str(content: str | bytes) -> str:
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
if isinstance(content, str):
return content
return content.decode("utf-8", errors="replace")
def _check_content_size(content: str | bytes) -> None:
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
(``_expand_bare_ref``) can unify all resolution errors into a single
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
For ``bytes``, the length is the byte count directly. For ``str``,
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
mean the byte size can be up to 4x the character count.
"""
if isinstance(content, bytes):
size = len(content)
else:
char_len = len(content)
# Fast lower bound: UTF-8 byte count >= char count.
# If char count already exceeds the limit, reject immediately
# without allocating an encoded copy.
if char_len > _MAX_BARE_REF_BYTES:
size = char_len # real byte size is even larger
# Fast upper bound: each char is at most 4 UTF-8 bytes.
# If worst-case is still under the limit, skip encoding entirely.
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
return
else:
# Edge case: char count is under limit but multibyte chars
# might push byte count over. Encode to get exact size.
size = len(content.encode("utf-8"))
if size > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large for structured parsing "
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
)
async def _infer_format_from_workspace(
uri: str,
user_id: str | None,
session: ChatSession,
) -> str | None:
"""Look up workspace file metadata to infer the format.
Workspace URIs by ID (``workspace://abc123``) have no file extension.
When the MIME fragment is also absent, we query the workspace file
manager for the file's stored MIME type and original filename.
"""
if not user_id:
return None
try:
ws = parse_workspace_uri(uri)
manager = await get_workspace_manager(user_id, session.session_id)
info = await (
manager.get_file_info(ws.file_ref)
if not ws.is_path
else manager.get_file_info_by_path(ws.file_ref)
)
if info is None:
return None
# Try MIME type first, then filename extension.
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
except (
ValueError,
FileNotFoundError,
OSError,
PermissionError,
AttributeError,
TypeError,
):
# Expected failures: bad URI, missing file, permission denied, or
# workspace manager returning unexpected types. Propagate anything
# else (e.g. programming errors) so they don't get silently swallowed.
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
return None
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
if prop_schema is None:
return False
return (
prop_schema.get("type") == "string"
and prop_schema.get("format") == MediaFileType.string_format
)
async def _expand_bare_ref(
ref: FileRef,
fmt: str | None,
user_id: str | None,
session: ChatSession,
prop_schema: dict[str, Any] | None,
) -> Any:
"""Resolve and parse a bare ``@@agptfile:`` reference.
This is the structured-parsing path: the file is read, optionally parsed
according to *fmt*, and adapted to the target *prop_schema*.
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
Note: MediaFileType fields (format: "file") are handled earlier in
``_expand`` to avoid unnecessary format inference and file I/O.
"""
try:
if fmt is not None and fmt in BINARY_FORMATS:
# Binary formats need raw bytes, not UTF-8 text.
# Line ranges are meaningless for binary formats (parquet/xlsx)
# — ignore them and parse full bytes. Warn so the caller/model
# knows the range was silently dropped.
if ref.start_line is not None or ref.end_line is not None:
logger.warning(
"Line range [%s-%s] ignored for binary format %s (%s); "
"binary formats are always parsed in full.",
ref.start_line,
ref.end_line,
fmt,
ref.uri,
)
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
else:
content = await resolve_file_ref(ref, user_id, session)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# For known formats this rejects files >10 MB before parsing.
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
# but this check still guards the parsing path which has no char limit.
# _check_content_size raises ValueError, which we unify here just like
# resolution errors above.
try:
_check_content_size(content)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# When the schema declares this parameter as "string",
# return raw file content — don't parse into a structured
# type that would need json.dumps() serialisation.
expect_string = (prop_schema or {}).get("type") == "string"
if expect_string:
if isinstance(content, bytes):
raise FileRefExpansionError(
f"Cannot use {fmt} file as text input: "
f"binary formats (parquet, xlsx) must be passed "
f"to a block that accepts structured data (list/object), "
f"not a string-typed parameter."
)
return content
if fmt is not None:
# Use strict mode for binary formats so we surface the
# actual error (e.g. missing pyarrow/openpyxl, corrupt
# file) instead of silently returning garbled bytes.
strict = fmt in BINARY_FORMATS
try:
parsed = parse_file_content(content, fmt, strict=strict)
except PARSE_EXCEPTIONS as exc:
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
# Normalize bytes fallback to str so tools never
# receive raw bytes when parsing fails.
if isinstance(parsed, bytes):
parsed = _to_str(parsed)
return _adapt_to_schema(parsed, prop_schema)
# Unknown format — return as plain string, but apply
# the same per-ref character limit used by inline refs
# to prevent injecting unexpectedly large content.
text = _to_str(content)
if len(text) > _MAX_EXPAND_CHARS:
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
return text
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
"""Adapt a parsed file value to better fit the target schema type.
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
that doesn't match the block's expected type, this function converts it to
a more useful representation instead of relying on pydantic's generic
coercion (which can produce awkward results like flattened dicts → lists).
Returns *parsed* unchanged when no adaptation is needed.
"""
if prop_schema is None:
return parsed
target_type = prop_schema.get("type")
# Dict → array: delegate to helper.
if isinstance(parsed, dict) and target_type == "array":
return _adapt_dict_to_array(parsed, prop_schema)
# List → object: delegate to helper (raises for non-tabular lists).
if isinstance(parsed, list) and target_type == "object":
return _adapt_list_to_object(parsed)
# Tabular list → Any (no type): convert to list of dicts.
# Blocks like FindInDictionaryBlock have `input: Any` which produces
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
# for key lookup, but [{col: val}, ...] works with FindInDict's
# list-of-dicts branch (line 195-199 in data_manipulation.py).
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
return _tabular_to_list_of_dicts(parsed)
return parsed
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
"""Adapt a parsed dict to an array-typed field.
Extracts list-valued entries when the target item type is ``array``,
passes through unchanged when item type is ``string`` (lets pydantic error),
or wraps in ``[parsed]`` as a fallback.
"""
items_type = (prop_schema.get("items") or {}).get("type")
if items_type == "array":
# Target is List[List[Any]] — extract list-typed values from the
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
list_values = [v for v in parsed.values() if isinstance(v, list)]
if list_values:
return list_values
if items_type == "string":
# Target is List[str] — wrapping a dict would give [dict]
# which can't coerce to strings. Return unchanged and let
# pydantic surface a clear validation error.
return parsed
# Fallback: wrap in a single-element list so the block gets [dict]
# instead of pydantic flattening keys/values into a flat list.
return [parsed]
def _adapt_list_to_object(parsed: list) -> Any:
"""Adapt a parsed list to an object-typed field.
Converts tabular lists to column-dicts; raises for non-tabular lists.
"""
if _is_tabular(parsed):
return _tabular_to_column_dict(parsed)
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
# be meaningfully coerced to an object. Raise explicitly so callers
# get a clear error rather than pydantic silently wrapping the list.
raise FileRefExpansionError(
"Cannot adapt a non-tabular list to an object-typed field. "
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
)
def _is_tabular(parsed: Any) -> bool:
"""Check if parsed data is in tabular format: [[header], [row1], ...].
Uses isinstance checks because this is a structural type guard on
opaque parser output (Any), not duck typing. A Protocol wouldn't
help here — we need to verify exact list-of-lists shape.
"""
if not isinstance(parsed, list) or len(parsed) < 2:
return False
header = parsed[0]
if not isinstance(header, list) or not header:
return False
if not all(isinstance(h, str) for h in header):
return False
return all(isinstance(row, list) for row in parsed[1:])
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
Ragged rows (fewer columns than the header) get None for missing values.
Extra values beyond the header length are silently dropped.
"""
header = parsed[0]
return [
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
for row in parsed[1:]
]
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
Ragged rows (fewer columns than the header) get None for missing values,
ensuring all columns have equal length.
"""
header = parsed[0]
return {
col: [row[i] if i < len(row) else None for row in parsed[1:]]
for i, col in enumerate(header)
}

View File

@@ -1,521 +0,0 @@
"""Integration tests for @@agptfile: reference expansion in tool calls.
These tests verify the end-to-end behaviour of the file reference protocol:
- Parsing @@agptfile: tokens from tool arguments
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
- The extended Read tool handler (workspace:// pass-through via session context)
No real LLM or database is required; workspace reads are stubbed where needed.
"""
from __future__ import annotations
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.sdk.file_ref import (
FileRef,
expand_file_refs_in_args,
expand_file_refs_in_string,
read_file_bytes,
resolve_file_ref,
)
from backend.copilot.sdk.tool_adapter import _read_file_handler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(session_id: str = "integ-sess") -> MagicMock:
s = MagicMock()
s.session_id = session_id
return s
# ---------------------------------------------------------------------------
# Local-file resolution (sdk_cwd)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path():
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
# Write a test file inside sdk_cwd
test_file = os.path.join(sdk_cwd, "hello.txt")
with open(test_file, "w") as f:
f.write("line1\nline2\nline3\n")
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=None, end_line=None)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line1\nline2\nline3\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_local_path_with_line_range():
"""resolve_file_ref respects line ranges for local files."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "multi.txt")
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
with open(test_file, "w") as f:
f.writelines(lines)
session = _make_session()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
ref = FileRef(uri=test_file, start_line=3, end_line=5)
content = await resolve_file_ref(ref, user_id="u1", session=session)
assert content == "line3\nline4\nline5\n"
@pytest.mark.asyncio
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var:
mock_cwd_var.get.return_value = sdk_cwd
mock_sandbox_var.get.return_value = None
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
with pytest.raises(ValueError, match="not allowed"):
await resolve_file_ref(ref, user_id="u1", session=_make_session())
# ---------------------------------------------------------------------------
# expand_file_refs_in_string — integration with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_string_with_real_file():
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "data.txt")
with open(test_file, "w") as f:
f.write("hello world\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"Content: @@agptfile:{test_file}",
user_id="u1",
session=_make_session(),
)
assert result == "Content: hello world\n"
@pytest.mark.asyncio
async def test_expand_string_missing_file_is_surfaced_inline():
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
with tempfile.TemporaryDirectory() as sdk_cwd:
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_string(
f"@@agptfile:{missing}",
user_id="u1",
session=_make_session(),
)
assert "[file-ref error:" in result
assert "not found" in result.lower() or "not allowed" in result.lower()
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — dict traversal with real files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_expand_args_replaces_file_ref_in_nested_dict():
"""Nested @@agptfile: references in args are fully expanded."""
with tempfile.TemporaryDirectory() as sdk_cwd:
file_a = os.path.join(sdk_cwd, "a.txt")
file_b = os.path.join(sdk_cwd, "b.txt")
with open(file_a, "w") as f:
f.write("AAA")
with open(file_b, "w") as f:
f.write("BBB")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{
"outer": {
"content_a": f"@@agptfile:{file_a}",
"content_b": f"start @@agptfile:{file_b} end",
},
"count": 42,
},
user_id="u1",
session=_make_session(),
)
assert result["outer"]["content_a"] == "AAA"
assert result["outer"]["content_b"] == "start BBB end"
assert result["count"] == 42
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — bare ref structured parsing
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bare_ref_json_returns_parsed_dict():
"""Bare ref to a .json file returns parsed dict, not raw string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value", "count": 42}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == {"key": "value", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_csv_returns_parsed_table():
"""Bare ref to a .csv file returns list[list[str]] table."""
with tempfile.TemporaryDirectory() as sdk_cwd:
csv_file = os.path.join(sdk_cwd, "data.csv")
with open(csv_file, "w") as f:
f.write("Name,Score\nAlice,90\nBob,85")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"input": f"@@agptfile:{csv_file}"},
user_id="u1",
session=_make_session(),
)
assert result["input"] == [
["Name", "Score"],
["Alice", "90"],
["Bob", "85"],
]
@pytest.mark.asyncio
async def test_bare_ref_unknown_extension_returns_string():
"""Bare ref to a file with unknown extension returns plain string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
txt_file = os.path.join(sdk_cwd, "readme.txt")
with open(txt_file, "w") as f:
f.write("plain text content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{txt_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "plain text content"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_bare_ref_invalid_json_falls_back_to_string():
"""Bare ref to a .json file with invalid JSON falls back to string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "bad.json")
with open(json_file, "w") as f:
f.write("not valid json {{{")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "not valid json {{{"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_embedded_ref_always_returns_string_even_for_json():
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value"}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"prefix @@agptfile:{json_file} suffix"},
user_id="u1",
session=_make_session(),
)
assert isinstance(result["data"], str)
assert result["data"].startswith("prefix ")
assert result["data"].endswith(" suffix")
@pytest.mark.asyncio
async def test_bare_ref_yaml_returns_parsed_dict():
"""Bare ref to a .yaml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
yaml_file = os.path.join(sdk_cwd, "config.yaml")
with open(yaml_file, "w") as f:
f.write("name: test\ncount: 42\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{yaml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_binary_with_line_range_ignores_range():
"""Bare ref to a binary file (.parquet) with line range parses the full file.
Binary formats (parquet, xlsx) ignore line ranges — the full content is
parsed and the range is silently dropped with a log warning.
"""
try:
import pandas as pd
except ImportError:
pytest.skip("pandas not installed")
try:
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
except ImportError:
pytest.skip("pyarrow not installed")
with tempfile.TemporaryDirectory() as sdk_cwd:
parquet_file = os.path.join(sdk_cwd, "data.parquet")
import io as _io
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
buf = _io.BytesIO()
df.to_parquet(buf, index=False)
with open(parquet_file, "wb") as f:
f.write(buf.getvalue())
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
# Line range [1-2] should be silently ignored for binary formats.
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{parquet_file}[1-2]"},
user_id="u1",
session=_make_session(),
)
# Full file is returned despite the line range.
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
@pytest.mark.asyncio
async def test_bare_ref_toml_returns_parsed_dict():
"""Bare ref to a .toml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
toml_file = os.path.join(sdk_cwd, "config.toml")
with open(toml_file, "w") as f:
f.write('name = "test"\ncount = 42\n')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{toml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_handler_local_file():
"""_read_file_handler reads a local file when it's within sdk_cwd."""
with tempfile.TemporaryDirectory() as sdk_cwd:
test_file = os.path.join(sdk_cwd, "read_test.txt")
lines = [f"L{i}\n" for i in range(1, 6)]
with open(test_file, "w") as f:
f.writelines(lines)
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj_var, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd_var.get.return_value = sdk_cwd
mock_proj_var.get.return_value = ""
result = await _read_file_handler(
{"file_path": test_file, "offset": 0, "limit": 5}
)
assert not result["isError"]
text = result["content"][0]["text"]
assert "L1" in text
assert "L5" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri():
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
mock_session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
)
assert not result["isError"], result["content"][0]["text"]
text = result["content"][0]["text"]
assert "workspace file content" in text
assert "line two" in text
@pytest.mark.asyncio
async def test_read_file_handler_workspace_uri_no_session():
"""_read_file_handler returns error when workspace:// is used without session."""
with patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=(None, None),
):
result = await _read_file_handler({"file_path": "workspace://some-id"})
assert result["isError"]
assert "session" in result["content"][0]["text"].lower()
@pytest.mark.asyncio
async def test_read_file_handler_access_denied():
"""_read_file_handler rejects paths outside allowed locations."""
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox, patch(
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", _make_session()),
):
mock_cwd.get.return_value = "/tmp/safe-dir"
mock_sandbox.get.return_value = None
result = await _read_file_handler({"file_path": "/etc/passwd"})
assert result["isError"]
assert "not allowed" in result["content"][0]["text"].lower()
# ---------------------------------------------------------------------------
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_file_bytes_workspace_virtual_path():
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
session = _make_session()
mock_manager = AsyncMock()
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
assert result == b"virtual path content"
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_sandbox_branch():
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
session = _make_session()
mock_sandbox = AsyncMock()
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
result = await read_file_bytes("/home/user/script.sh", None, session)
assert result == b"sandbox content"
mock_sandbox.files.read.assert_awaited_once_with(
"/home/user/script.sh", format="bytes"
)
@pytest.mark.asyncio
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
session = _make_session()
mock_sandbox = AsyncMock()
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
"backend.copilot.context._current_sandbox"
) as mock_sandbox_var, patch(
"backend.copilot.context._current_project_dir"
) as mock_proj:
mock_cwd.get.return_value = ""
mock_sandbox_var.get.return_value = mock_sandbox
mock_proj.get.return_value = ""
with pytest.raises(ValueError, match="not allowed"):
await read_file_bytes("/etc/passwd", None, session)

File diff suppressed because it is too large Load Diff

View File

@@ -1,59 +0,0 @@
## MCP Tool Guide
### Workflow
`run_mcp_tool` follows a two-step pattern:
1. **Discover** — call with only `server_url` to list available tools on the server.
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
### Known hosted MCP servers
Use these URLs directly without asking the user:
| Service | URL |
|---|---|
| Notion | `https://mcp.notion.com/mcp` |
| Linear | `https://mcp.linear.app/mcp` |
| Stripe | `https://mcp.stripe.com` |
| Intercom | `https://mcp.intercom.com/mcp` |
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
For other services, search the MCP registry API:
```http
GET https://registry.modelcontextprotocol.io/v0/servers?q=<search_term>
```
Each result includes a `remotes` array with the exact server URL to use.
### Important: Check blocks first
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
Google Calendar, Gmail, etc.) that work without MCP setup.
Only use `run_mcp_tool` when:
- The service is in the known hosted MCP servers list above, OR
- You searched `find_block` first and found no matching blocks
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
or from the `remotes[].url` field in MCP registry search results.
### Authentication
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
### Communication style
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
Use plain, friendly language instead:
| Instead of… | Say… |
|---|---|
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".

View File

@@ -536,12 +536,10 @@ async def test_wait_for_stash_signaled():
result = await wait_for_stash(timeout=1.0)
assert result is True
pto = _pto.get()
assert pto is not None
assert pto.get("WebSearch") == ["result data"]
assert _pto.get({}).get("WebSearch") == ["result data"]
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@@ -556,7 +554,7 @@ async def test_wait_for_stash_timeout():
assert result is False
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@@ -575,12 +573,10 @@ async def test_wait_for_stash_already_stashed():
assert result is True
# But the stash itself is populated
pto = _pto.get()
assert pto is not None
assert pto.get("Read") == ["file contents"]
assert _pto.get({}).get("Read") == ["file contents"]
# Cleanup
_pto.set({})
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)

View File

@@ -10,13 +10,12 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.copilot.context import is_allowed_local_path
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -127,7 +126,7 @@ def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[str], None] | None = None,
on_compact: Callable[[], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -142,7 +141,6 @@ def create_security_hooks(
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_compact: Callback invoked when SDK starts compacting context.
Receives the transcript_path from the hook input.
Returns:
Hooks configuration dict for ClaudeAgentOptions
@@ -302,21 +300,11 @@ def create_security_hooks(
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
# Sanitize untrusted input before logging to prevent log injection
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
trigger,
user_id,
transcript_path,
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
if on_compact is not None:
on_compact(transcript_path)
on_compact()
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {

View File

@@ -9,9 +9,8 @@ import os
import pytest
from backend.copilot.context import _current_project_dir
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
@@ -121,6 +120,8 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
@@ -132,14 +133,16 @@ def test_read_tool_results_allowed():
_current_project_dir.reset(token)
def test_read_claude_projects_settings_json_denied():
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
assert not _is_denied(result)
finally:
_current_project_dir.reset(token)
@@ -354,3 +357,76 @@ async def test_task_slot_released_on_failure(_hooks):
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

View File

@@ -29,7 +29,6 @@ from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -61,8 +60,9 @@ from ..service import (
_generate_session_title,
_is_langfuse_configured,
)
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.e2b_sandbox import get_or_create_sandbox
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
@@ -77,7 +77,6 @@ from .tool_adapter import (
from .transcript import (
cleanup_cli_project_dir,
download_transcript,
read_compacted_entries,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -457,6 +456,31 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
def _is_tool_error_or_denial(content: str | None) -> bool:
"""Check if a tool message content indicates an error or denial.
Currently unused — ``_format_conversation_context`` includes all tool
results. Kept as a utility for future selective filtering.
"""
if not content:
return False
lower = content.lower()
return any(
marker in lower
for marker in (
"[security]",
"cannot be bypassed",
"not allowed",
"not supported", # background-task denial
"maximum", # subtask-limit denial
"denied",
"blocked",
"failed to", # internal tool execution failures
'"iserror": true', # MCP protocol error flag
)
)
async def _build_query_message(
current_message: str,
session: ChatSession,
@@ -565,7 +589,7 @@ async def _prepare_file_attachments(
return empty
try:
manager = await get_workspace_manager(user_id, session_id)
manager = await get_manager(user_id, session_id)
except Exception:
logger.warning(
"Failed to create workspace manager for file attachments",
@@ -760,29 +784,28 @@ async def stream_chat_completion_sdk(
async def _setup_e2b():
"""Set up E2B sandbox if configured, return sandbox or None."""
if not (e2b_api_key := config.active_e2b_api_key):
if config.use_e2b_sandbox:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
)
return None
try:
return await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
on_timeout=config.e2b_sandbox_on_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
if config.use_e2b_sandbox and not config.e2b_api_key:
logger.warning(
"[E2B] [%s] E2B sandbox enabled but no API key configured "
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
if config.use_e2b_sandbox and config.e2b_api_key:
try:
return await get_or_create_sandbox(
session_id,
api_key=config.e2b_api_key,
template=config.e2b_sandbox_template,
timeout=config.e2b_sandbox_timeout,
)
except Exception as e2b_err:
logger.error(
"[E2B] [%s] Setup failed: %s",
session_id[:12],
e2b_err,
exc_info=True,
)
return None
async def _fetch_transcript():
@@ -814,6 +837,7 @@ async def stream_chat_completion_sdk(
system_prompt = base_system_prompt + get_sdk_supplement(
use_e2b=use_e2b, cwd=sdk_cwd
)
# Process transcript download result
transcript_msg_count = 0
if dl:
@@ -878,11 +902,6 @@ async def stream_chat_completion_sdk(
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
sid = session_id[:12] if session_id else "?"
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
@@ -891,7 +910,6 @@ async def stream_chat_completion_sdk(
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
"stderr": _on_stderr,
}
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
@@ -1046,7 +1064,6 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
ended_with_stream_error = True
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
code="sdk_stream_error",
@@ -1065,19 +1082,6 @@ async def stream_chat_completion_sdk(
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Race-condition fix: SDK hooks (PostToolUse) are
# executed asynchronously via start_soon() — the next
# message can arrive before the hook stashes output.
@@ -1131,26 +1135,9 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with the
# CLI's active context so they stay identical.
compact_result = await compaction.emit_end_if_ready(session)
for ev in compact_result.events:
# Emit compaction end if SDK finished compacting
for ev in await compaction.emit_end_if_ready(session):
yield ev
# After replace_entries, skip append_assistant for this
# sdk_msg — the CLI session file already contains it,
# so appending again would create a duplicate.
entries_replaced = False
if compact_result.just_ended:
compacted = await asyncio.to_thread(
read_compacted_entries,
compact_result.transcript_path,
)
if compacted is not None:
transcript_builder.replace_entries(
compacted, log_prefix=log_prefix
)
entries_replaced = True
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
@@ -1237,11 +1224,10 @@ async def stream_chat_completion_sdk(
tool_call_id=response.toolCallId,
)
)
if not entries_replaced:
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
@@ -1251,9 +1237,7 @@ async def stream_chat_completion_sdk(
# any stashed tool results from the previous turn are
# recorded first, preserving the required API order:
# assistant(tool_use) → tool_result → assistant(text).
# Skip if replace_entries just ran — the CLI session
# file already contains this message.
if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
if isinstance(sdk_msg, AssistantMessage):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
@@ -1432,25 +1416,14 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
# --- Pause E2B sandbox to stop billing between turns ---
# Fire-and-forget: pausing is best-effort and must not block the
# response or the transcript upload. The task is anchored to
# _background_tasks to prevent garbage collection.
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
# round-trip — e2b_sandbox is the live object from this turn.
if e2b_sandbox is not None:
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# TranscriptBuilder is the single source of truth. It mirrors the
# CLI's active context: on compaction, replace_entries() syncs it
# with the compacted session file. No CLI file read needed here.
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception.
# The transcript represents the COMPLETE active context (atomic).
if config.claude_agent_use_resume and user_id and session is not None:
try:
# Build complete transcript from captured SDK messages
transcript_content = transcript_builder.to_jsonl()
entry_count = transcript_builder.entry_count
if not transcript_content:
logger.warning(
@@ -1460,15 +1433,18 @@ async def stream_chat_completion_sdk(
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
entry_count,
transcript_builder.entry_count,
)
else:
logger.info(
"%s Uploading transcript (entries=%d, bytes=%d)",
"%s Uploading complete transcript (entries=%d, bytes=%d)",
log_prefix,
entry_count,
transcript_builder.entry_count,
len(transcript_content),
)
# Shield upload from cancellation - let it complete even if
# the finally block is interrupted. No timeout to avoid race
# conditions where backgrounded uploads overwrite newer transcripts.
await asyncio.shield(
upload_transcript(
user_id=user_id,

View File

@@ -1,10 +1,9 @@
"""Tests for SDK service helpers."""
import asyncio
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
@@ -20,7 +19,7 @@ class _FakeFileInfo:
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
class TestPrepareFileAttachments:
@@ -213,7 +212,7 @@ class TestPromptSupplement:
# Workflows are now in individual tool descriptions (not separate sections)
# Check that key workflow concepts appear in tool descriptions
assert "agent_json" in docs or "find_block" in docs
assert "suggested_goal" in docs or "clarifying_questions" in docs
assert "run_mcp_tool" in docs
def test_baseline_supplement_completeness(self):
@@ -232,48 +231,6 @@ class TestPromptSupplement:
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
def test_pause_task_scheduled_before_transcript_upload(self):
"""Pause is scheduled as a background task before transcript upload begins.
The finally block in stream_response_sdk does:
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
(2) await asyncio.shield(upload_transcript(...)) — awaited
Scheduling pause via create_task before awaiting upload ensures:
- Pause never blocks transcript upload (billing stops concurrently)
- On E2B timeout, pause silently fails; upload proceeds unaffected
"""
call_order: list[str] = []
async def _mock_pause(sandbox, session_id):
call_order.append("pause")
async def _mock_upload(**kwargs):
call_order.append("upload")
async def _simulate_teardown():
"""Mirror the service.py finally block teardown sequence."""
sandbox = MagicMock()
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
# Yielding to the event loop here lets the pause task start concurrently.
await _mock_upload(
user_id="u", session_id="test-sess", content="x", message_count=1
)
await task
asyncio.run(_simulate_teardown())
# Both must run; pause is scheduled before upload starts
assert "pause" in call_order
assert "upload" in call_order
# create_task schedules pause, then upload is awaited — pause runs
# concurrently during upload's first yield. The ordering guarantee is
# that create_task is CALLED before upload is AWAITED (see source order).
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement

View File

@@ -9,29 +9,14 @@ import itertools
import json
import logging
import os
import re
import uuid
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from backend.copilot.context import (
_current_project_dir,
_current_sandbox,
_current_sdk_cwd,
_current_session,
_current_user_id,
_encode_cwd_for_cli,
get_execution_context,
get_sdk_cwd,
is_allowed_local_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import (
FileRefExpansionError,
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -43,13 +28,84 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
@@ -93,6 +149,24 @@ def set_execution_context(
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
@@ -185,11 +259,7 @@ async def _execute_tool_sync(
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response.
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
"""
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
@@ -250,50 +320,42 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit.
"""Read a local file with optional offset/limit.
Supports ``workspace://`` URIs (delegated to the workspace manager) and
local paths within the session's allowed directories (sdk_cwd + tool-results).
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
"""
file_path = args.get("file_path", "")
offset = max(0, int(args.get("offset", 0)))
limit = max(1, int(args.get("limit", 2000)))
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
def _mcp_err(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": True}
def _mcp_ok(text: str) -> dict[str, Any]:
return {"content": [{"type": "text", "text": text}], "isError": False}
if file_path.startswith("workspace://"):
user_id, session = get_execution_context()
if session is None:
return _mcp_err("workspace:// file references require an active session")
try:
raw = await read_file_bytes(file_path, user_id, session)
except ValueError as exc:
return _mcp_err(str(exc))
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp_ok(numbered)
if not is_allowed_local_path(file_path, get_sdk_cwd()):
return _mcp_err(f"Path not allowed: {file_path}")
if not is_allowed_local_path(file_path):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return _mcp_ok("".join(selected))
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
except FileNotFoundError:
return _mcp_err(f"File not found: {file_path}")
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return _mcp_err(f"Error reading file: {e}")
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
@@ -347,30 +409,14 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
:func:`get_sdk_disallowed_tools`.
"""
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
def _truncating(fn, tool_name: str):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
Also expands ``@@agptfile:`` references in args so every registered tool
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(
args, user_id, session, input_schema=input_schema
)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
"Ensure the file exists before referencing it. "
"For sandbox paths use bash_exec to verify the file exists first; "
"for workspace files use a workspace:// URI."
)
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)
@@ -391,12 +437,11 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
schema = _build_input_schema(base_tool)
decorated = tool(
tool_name,
base_tool.description,
schema,
)(_truncating(handler, tool_name, input_schema=schema))
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.

View File

@@ -2,12 +2,12 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,

View File

@@ -13,10 +13,8 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from backend.util import json
@@ -84,11 +82,7 @@ def strip_progress_entries(content: str) -> str:
parent = entry.get("parentUuid", "")
if uid:
uuid_to_parent[uid] = parent
if (
entry.get("type", "") in STRIPPABLE_TYPES
and uid
and not entry.get("isCompactSummary")
):
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
stripped_uuids.add(uid)
# Second pass: keep non-stripped entries, reparenting where needed.
@@ -112,9 +106,7 @@ def strip_progress_entries(content: str) -> str:
if not isinstance(entry, dict):
result_lines.append(line)
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
if entry.get("type", "") in STRIPPABLE_TYPES:
continue
uid = entry.get("uuid", "")
if uid in reparented:
@@ -145,155 +137,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _projects_base() -> str:
"""Return the resolved path to the CLI's projects directory."""
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
return os.path.realpath(os.path.join(config_dir, "projects"))
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
Returns ``None`` if the path would escape the projects base.
"""
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
projects_base = _projects_base()
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] Project dir escaped projects base: %s", project_dir
)
return None
return project_dir
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
try:
resolved_base = Path(project_dir).resolve()
except OSError as e:
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
return []
result: list[Path] = []
for candidate in Path(project_dir).glob("*.jsonl"):
try:
resolved = candidate.resolve()
if resolved.is_relative_to(resolved_base):
result.append(resolved)
except (OSError, RuntimeError) as e:
logger.debug(
"[Transcript] Skipping invalid CLI session candidate %s: %s",
candidate,
e,
)
return result
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
"""Read compacted entries from the CLI session file after compaction.
Parses the JSONL file line-by-line, finds the ``isCompactSummary: true``
entry, and returns it plus all entries after it.
The CLI writes the compaction summary BEFORE sending the next message,
so the file is guaranteed to be flushed by the time we read it.
Returns a list of parsed dicts, or ``None`` if the file cannot be read
or no compaction summary is found.
"""
if not transcript_path:
return None
projects_base = _projects_base()
real_path = os.path.realpath(transcript_path)
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] transcript_path outside projects base: %s", transcript_path
)
return None
try:
content = Path(real_path).read_text()
except OSError as e:
logger.warning(
"[Transcript] Failed to read session file %s: %s", transcript_path, e
)
return None
lines = content.strip().split("\n")
compact_idx: int | None = None
for idx, line in enumerate(lines):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("isCompactSummary"):
compact_idx = idx # don't break — find the LAST summary
if compact_idx is None:
logger.debug("[Transcript] No compaction summary found in %s", transcript_path)
return None
entries: list[dict] = []
for line in lines[compact_idx:]:
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if isinstance(entry, dict):
entries.append(entry)
logger.info(
"[Transcript] Read %d compacted entries from %s (summary at line %d)",
len(entries),
transcript_path,
compact_idx + 1,
)
return entries
def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any compaction.
The CLI writes its session transcript to
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
Since each SDK turn uses a unique ``sdk_cwd``, there should be
exactly one ``.jsonl`` file in that directory.
Returns the file content, or ``None`` if not found.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = _safe_glob_jsonl(project_dir)
if not jsonl_files:
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
return None
# Pick the most recently modified file (should be only one per turn).
try:
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
except OSError as e:
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
return None
try:
content = session_file.read_text()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
session_file,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
@@ -301,15 +144,25 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
import shutil
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
def write_transcript_to_tempfile(
@@ -406,27 +259,24 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
)
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
``local://workspace_id/file_id/filename``. Since we use deterministic
arguments we can reconstruct the same path for download/delete without
having stored the return value.
"""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = parts
wid, fid, fname = _storage_path_parts(user_id, session_id)
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
return f"gcs://{backend.bucket_name}/{blob}"
return f"local://{wid}/{fid}/{fname}"
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects."""
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path for the companion .meta.json file."""
return _build_path_from_parts(
_meta_storage_path_parts(user_id, session_id), backend
)
else:
# LocalWorkspaceStorage returns local://{relative_path}
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(
@@ -531,7 +381,15 @@ async def download_transcript(
message_count = 0
uploaded_at = 0.0
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
from backend.util.workspace_storage import GCSWorkspaceStorage
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
if isinstance(storage, GCSWorkspaceStorage):
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
meta_path = f"gcs://{storage.bucket_name}/{blob}"
else:
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
@@ -548,11 +406,7 @@ async def download_transcript(
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript and its metadata from bucket storage.
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
@@ -560,14 +414,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
try:
await storage.delete(path)
logger.info("[Transcript] Deleted transcript for session %s", session_id)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning("[Transcript] Failed to delete transcript: %s", e)
# Also delete the companion .meta.json to avoid orphaned metadata.
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
await storage.delete(meta_path)
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -30,7 +30,6 @@ class TranscriptEntry(BaseModel):
type: str
uuid: str
parentUuid: str | None
isCompactSummary: bool | None = None
message: dict[str, Any]
@@ -54,24 +53,6 @@ class TranscriptBuilder:
return self._entries[-1].message.get("id", "")
return ""
@staticmethod
def _parse_entry(data: dict) -> TranscriptEntry | None:
"""Parse a single transcript entry, filtering strippable types.
Returns ``None`` for entries that should be skipped (strippable types
that are not compaction summaries).
"""
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
return None
return TranscriptEntry(
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
isCompactSummary=data.get("isCompactSummary") or None,
message=data.get("message", {}),
)
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
@@ -97,9 +78,18 @@ class TranscriptBuilder:
)
continue
entry = self._parse_entry(data)
if entry is None:
# Load all non-strippable entries (user/assistant/system/etc.)
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES:
continue
entry = TranscriptEntry(
type=data["type"],
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
message=data.get("message", {}),
)
self._entries.append(entry)
self._last_uuid = entry.uuid
@@ -172,43 +162,6 @@ class TranscriptBuilder:
)
self._last_uuid = msg_uuid
def replace_entries(
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
) -> None:
"""Replace all entries with compacted entries from the CLI session file.
Called after mid-stream compaction so TranscriptBuilder mirrors the
CLI's active context (compaction summary + post-compaction entries).
Builds the new list first and validates it's non-empty before swapping,
so corrupt input cannot wipe the conversation history.
"""
new_entries: list[TranscriptEntry] = []
for data in compacted_entries:
entry = self._parse_entry(data)
if entry is not None:
new_entries.append(entry)
if not new_entries:
logger.warning(
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
log_prefix,
len(compacted_entries),
len(self._entries),
)
return
old_count = len(self._entries)
self._entries = new_entries
self._last_uuid = new_entries[-1].uuid
logger.info(
"%s TranscriptBuilder compacted: %d entries -> %d entries",
log_prefix,
old_count,
len(self._entries),
)
def to_jsonl(self) -> str:
"""Export complete context as JSONL.

View File

@@ -1,23 +1,15 @@
"""Unit tests for JSONL transcript management utilities."""
import os
from unittest.mock import AsyncMock, patch
import pytest
from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
_cli_project_dir,
delete_transcript,
read_cli_session_file,
read_compacted_entries,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
def _make_jsonl(*entries: dict) -> str:
@@ -290,610 +282,3 @@ class TestStripProgressEntries:
lines = result.strip().split("\n")
asst_entry = json.loads(lines[-1])
assert asst_entry["parentUuid"] == "u1" # reparented
# --- read_cli_session_file ---
class TestReadCliSessionFile:
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
"""read_cli_session_file returns None when no .jsonl files exist."""
# Create a project dir with no jsonl files
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
assert read_cli_session_file("/fake/cwd") is None
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
"""read_cli_session_file returns the content of a single .jsonl file."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
jsonl_file = project_dir / "session.jsonl"
jsonl_file.write_text("line1\nline2\n")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
result = read_cli_session_file("/fake/cwd")
assert result == "line1\nline2\n"
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
"""read_cli_session_file skips symlinks that escape the project dir."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
# Create a file outside the project dir
outside = tmp_path / "outside"
outside.mkdir()
outside_file = outside / "evil.jsonl"
outside_file.write_text("should not be read\n")
# Symlink from inside project_dir to outside file
symlink = project_dir / "evil.jsonl"
symlink.symlink_to(outside_file)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
# The symlink target resolves outside project_dir, so it should be skipped
result = read_cli_session_file("/fake/cwd")
assert result is None
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
config_dir = tmp_path / "config"
config_dir.mkdir()
projects_dir = config_dir / "projects"
projects_dir.mkdir()
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# Create a symlink inside projects/ that points outside of it.
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
# cwd whose encoded form matches the symlink name we create.
evil_target = tmp_path / "escaped"
evil_target.mkdir()
# The encoded form of "/evil/cwd" is "-evil-cwd"
symlink_path = projects_dir / "-evil-cwd"
symlink_path.symlink_to(evil_target)
result = _cli_project_dir("/evil/cwd")
assert result is None
# --- delete_transcript ---
class TestDeleteTranscript:
@pytest.mark.asyncio
async def test_deletes_both_jsonl_and_meta(self):
"""delete_transcript removes both the .jsonl and .meta.json files."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock()
with patch(
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 2
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
assert any(p.endswith(".jsonl") for p in paths)
assert any(p.endswith(".meta.json") for p in paths)
@pytest.mark.asyncio
async def test_continues_on_jsonl_delete_failure(self):
"""If .jsonl delete fails, .meta.json delete is still attempted."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[Exception("jsonl delete failed"), None]
)
with patch(
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
# Should not raise
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 2
@pytest.mark.asyncio
async def test_handles_meta_delete_failure(self):
"""If .meta.json delete fails, no exception propagates."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[None, Exception("meta delete failed")]
)
with patch(
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
# Should not raise
await delete_transcript("user-123", "session-456")
# --- read_compacted_entries ---
COMPACT_SUMMARY = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "assistant", "content": "compacted context"},
}
POST_COMPACT_ASST = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "response after compaction"},
}
class TestReadCompactedEntries:
def test_returns_summary_and_entries_after(self, tmp_path, monkeypatch):
"""File with isCompactSummary entry returns summary + entries after."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
pre_compact = {"type": "user", "uuid": "u1", "message": {"role": "user"}}
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(pre_compact, COMPACT_SUMMARY, POST_COMPACT_ASST))
result = read_compacted_entries(str(path))
assert result is not None
assert len(result) == 2
assert result[0]["isCompactSummary"] is True
assert result[1]["uuid"] == "a2"
def test_no_compact_summary_returns_none(self, tmp_path, monkeypatch):
"""File without isCompactSummary returns None."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(USER_MSG, ASST_MSG))
result = read_compacted_entries(str(path))
assert result is None
def test_file_not_found_returns_none(self, tmp_path, monkeypatch):
"""Non-existent file returns None."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
projects_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
result = read_compacted_entries(str(projects_dir / "missing.jsonl"))
assert result is None
def test_empty_path_returns_none(self):
"""Empty string path returns None."""
result = read_compacted_entries("")
assert result is None
def test_malformed_json_lines_skipped(self, tmp_path, monkeypatch):
"""Malformed JSON lines are skipped gracefully."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
path = session_dir / "session.jsonl"
content = "not valid json\n" + json.dumps(COMPACT_SUMMARY) + "\n"
content += "also bad\n" + json.dumps(POST_COMPACT_ASST) + "\n"
path.write_text(content)
result = read_compacted_entries(str(path))
assert result is not None
assert len(result) == 2 # summary + post-compact assistant
def test_multiple_compact_summaries_uses_last(self, tmp_path, monkeypatch):
"""When multiple isCompactSummary entries exist, uses the last one
(most recent compaction)."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
second_summary = {
"type": "summary",
"uuid": "cs2",
"isCompactSummary": True,
"message": {"role": "assistant", "content": "second summary"},
}
path = session_dir / "session.jsonl"
path.write_text(_make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST, second_summary))
result = read_compacted_entries(str(path))
assert result is not None
# Last summary found, so only cs2 returned
assert len(result) == 1
assert result[0]["uuid"] == "cs2"
def test_path_outside_projects_base_returns_none(self, tmp_path, monkeypatch):
"""Transcript path outside the projects directory is rejected."""
config_dir = tmp_path / "config"
(config_dir / "projects").mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
evil_file = tmp_path / "evil.jsonl"
evil_file.write_text(_make_jsonl(COMPACT_SUMMARY))
result = read_compacted_entries(str(evil_file))
assert result is None
# --- TranscriptBuilder.replace_entries ---
class TestTranscriptBuilderReplaceEntries:
def test_replaces_existing_entries(self):
"""replace_entries replaces all entries with compacted ones."""
builder = TranscriptBuilder()
builder.append_user("hello")
builder.append_assistant([{"type": "text", "text": "world"}])
assert builder.entry_count == 2
compacted = [
{
"type": "user",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "compacted summary"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "response"},
},
]
builder.replace_entries(compacted)
assert builder.entry_count == 2
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["uuid"] == "cs1"
assert entries[1]["uuid"] == "a1"
def test_filters_strippable_types(self):
"""Strippable types are filtered out during replace."""
builder = TranscriptBuilder()
compacted = [
{
"type": "user",
"uuid": "cs1",
"message": {"role": "user", "content": "compacted summary"},
},
{"type": "progress", "uuid": "p1", "message": {}},
{"type": "summary", "uuid": "s1", "message": {}},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "hi"},
},
]
builder.replace_entries(compacted)
assert builder.entry_count == 2 # progress and summary were filtered
def test_maintains_last_uuid_chain(self):
"""After replace, _last_uuid is the last entry's uuid."""
builder = TranscriptBuilder()
compacted = [
{
"type": "user",
"uuid": "cs1",
"message": {"role": "user", "content": "compacted summary"},
},
{
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "hi"},
},
]
builder.replace_entries(compacted)
# Appending a new user message should chain to a1
builder.append_user("next question")
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[-1]["parentUuid"] == "a1"
def test_empty_entries_list_keeps_existing(self):
"""Replacing with empty list keeps existing entries (safety check)."""
builder = TranscriptBuilder()
builder.append_user("hello")
builder.replace_entries([])
# Empty input is treated as corrupt — existing entries preserved
assert builder.entry_count == 1
assert not builder.is_empty
# --- TranscriptBuilder.load_previous with compacted content ---
class TestTranscriptBuilderLoadPreviousCompacted:
def test_preserves_compact_summary_entry(self):
"""load_previous preserves isCompactSummary entries even though
their type is 'summary' (which is in STRIPPABLE_TYPES)."""
compacted_content = _make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST)
builder = TranscriptBuilder()
builder.load_previous(compacted_content)
assert builder.entry_count == 2
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["type"] == "summary"
assert entries[0]["uuid"] == "cs1"
assert entries[1]["uuid"] == "a2"
def test_strips_regular_summary_entries(self):
"""Regular summary entries (without isCompactSummary) are still stripped."""
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
content = _make_jsonl(regular_summary, POST_COMPACT_ASST)
builder = TranscriptBuilder()
builder.load_previous(content)
assert builder.entry_count == 1 # Only the assistant entry
# --- End-to-end compaction flow (simulates service.py) ---
class TestCompactionFlowIntegration:
"""Simulate the full compaction flow as it happens in service.py:
1. TranscriptBuilder loads a previous transcript (download)
2. New messages are appended (user query + assistant response)
3. CompactionTracker fires (PreCompact hook → emit_start → emit_end)
4. read_compacted_entries reads the CLI session file
5. TranscriptBuilder.replace_entries syncs with CLI state
6. Final to_jsonl() produces the correct output (upload)
"""
def test_full_compaction_roundtrip(self, tmp_path, monkeypatch):
"""Full roundtrip: load → append → compact → replace → export."""
# Setup: create a CLI session file with pre-compact + compaction entries
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# Simulate a transcript with old messages, then a compaction summary
old_user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "old question"},
}
old_asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {"role": "assistant", "content": "old answer"},
}
compact_summary = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "compacted summary of conversation"},
}
post_compact_asst = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "response after compaction"},
}
session_file = session_dir / "session.jsonl"
session_file.write_text(
_make_jsonl(old_user, old_asst, compact_summary, post_compact_asst)
)
# Step 1: TranscriptBuilder loads previous transcript (simulates download)
# The previous transcript would have the OLD entries (pre-compaction)
previous_transcript = _make_jsonl(old_user, old_asst)
builder = TranscriptBuilder()
builder.load_previous(previous_transcript)
assert builder.entry_count == 2
# Step 2: New messages appended during the current query
builder.append_user("new question")
builder.append_assistant([{"type": "text", "text": "new answer"}])
assert builder.entry_count == 4
# Step 3: read_compacted_entries reads the CLI session file
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
assert len(compacted) == 2 # compact_summary + post_compact_asst
assert compacted[0]["isCompactSummary"] is True
# Step 4: replace_entries syncs builder with CLI state
builder.replace_entries(compacted)
assert builder.entry_count == 2 # Only compacted entries now
# Step 5: Append post-compaction messages (continuing the stream)
builder.append_user("follow-up question")
assert builder.entry_count == 3
# Step 6: Export and verify
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert len(entries) == 3
# First entry is the compaction summary
assert entries[0]["type"] == "summary"
assert entries[0]["uuid"] == "cs1"
# Second is the post-compact assistant
assert entries[1]["uuid"] == "a2"
# Third is our follow-up, parented to the last compacted entry
assert entries[2]["type"] == "user"
assert entries[2]["parentUuid"] == "a2"
def test_compaction_preserves_chain_across_multiple_compactions(
self, tmp_path, monkeypatch
):
"""Two compactions: first compacts old history, second compacts the first."""
config_dir = tmp_path / "config"
projects_dir = config_dir / "projects"
session_dir = projects_dir / "proj"
session_dir.mkdir(parents=True)
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# First compaction
first_summary = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "first summary"},
}
mid_asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "mid response"},
}
# Second compaction (compacts the first summary + mid_asst)
second_summary = {
"type": "summary",
"uuid": "cs2",
"isCompactSummary": True,
"message": {"role": "user", "content": "second summary"},
}
final_asst = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "cs2",
"message": {"role": "assistant", "content": "final response"},
}
session_file = session_dir / "session.jsonl"
session_file.write_text(
_make_jsonl(first_summary, mid_asst, second_summary, final_asst)
)
# read_compacted_entries should find the LAST summary
compacted = read_compacted_entries(str(session_file))
assert compacted is not None
assert len(compacted) == 2 # second_summary + final_asst
assert compacted[0]["uuid"] == "cs2"
# Apply to builder
builder = TranscriptBuilder()
builder.append_user("old stuff")
builder.append_assistant([{"type": "text", "text": "old response"}])
builder.replace_entries(compacted)
assert builder.entry_count == 2
# New message chains correctly
builder.append_user("after second compaction")
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[-1]["parentUuid"] == "a2"
def test_strip_progress_preserves_compact_summaries(self):
"""strip_progress_entries doesn't strip isCompactSummary entries
even though their type is 'summary' (in STRIPPABLE_TYPES)."""
compact_summary = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "compacted"},
}
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
progress = {"type": "progress", "uuid": "p1", "data": {"stdout": "..."}}
user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
}
content = _make_jsonl(compact_summary, regular_summary, progress, user)
stripped = strip_progress_entries(content)
stripped_entries = [
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
]
uuids = [e.get("uuid") for e in stripped_entries]
# compact_summary kept, regular_summary stripped, progress stripped, user kept
assert "cs1" in uuids # compact summary preserved
assert "s1" not in uuids # regular summary stripped
assert "p1" not in uuids # progress stripped
assert "u1" in uuids # user kept
def test_builder_load_then_replace_then_export_roundtrip(self):
"""Load a compacted transcript, replace with new compaction, export.
Simulates two consecutive turns with compaction each time."""
# Turn 1: load compacted transcript
compact1 = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "summary v1"},
}
asst1 = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "content": "response 1"},
}
builder = TranscriptBuilder()
builder.load_previous(_make_jsonl(compact1, asst1))
assert builder.entry_count == 2
# Turn 1: append new messages
builder.append_user("question")
builder.append_assistant([{"type": "text", "text": "answer"}])
assert builder.entry_count == 4
# Turn 1: compaction fires — replace with new compacted state
compact2 = {
"type": "summary",
"uuid": "cs2",
"isCompactSummary": True,
"message": {"role": "user", "content": "summary v2"},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "cs2",
"message": {"role": "assistant", "content": "continuing"},
}
builder.replace_entries([compact2, asst2])
assert builder.entry_count == 2
# Export (this goes to cloud storage for next turn's download)
output = builder.to_jsonl()
lines = [json.loads(line) for line in output.strip().split("\n")]
assert lines[0]["uuid"] == "cs2"
assert lines[0]["type"] == "summary"
assert lines[1]["uuid"] == "a2"
# Turn 2: fresh builder loads the exported transcript
builder2 = TranscriptBuilder()
builder2.load_previous(output)
assert builder2.entry_count == 2
builder2.append_user("turn 2 question")
output2 = builder2.to_jsonl()
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
assert lines2[-1]["parentUuid"] == "a2"

View File

@@ -28,24 +28,10 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
settings = Settings()
_client: LangfuseAsyncOpenAI | None = None
_langfuse = None
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
def _get_openai_client() -> LangfuseAsyncOpenAI:
global _client
if _client is None:
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
return _client
def _get_langfuse():
global _langfuse
if _langfuse is None:
_langfuse = get_client()
return _langfuse
langfuse = get_client()
# Default system prompt used when Langfuse is not configured
# Provides minimal baseline tone and personality - all workflow, tools, and
@@ -98,7 +84,7 @@ async def _get_system_prompt_template(context: str) -> str:
else "latest"
)
prompt = await asyncio.to_thread(
_get_langfuse().get_prompt,
langfuse.get_prompt,
config.langfuse_prompt_name,
label=label,
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
@@ -172,7 +158,7 @@ async def _generate_session_title(
"environment": settings.config.app_env.value,
}
response = await _get_openai_client().chat.completions.create(
response = await client.chat.completions.create(
model=config.title_model,
messages=[
{

View File

@@ -23,11 +23,6 @@ from typing import Any, Literal
import orjson
from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
AsyncRedisNotificationEventBus,
NotificationEvent,
)
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
@@ -43,7 +38,6 @@ from .response_model import (
logger = logging.getLogger(__name__)
config = ChatConfig()
_notification_bus = AsyncRedisNotificationEventBus()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_sessions: dict[str, asyncio.Task] = {}
@@ -751,29 +745,6 @@ async def mark_session_completed(
# Clean up local session reference if exists
_local_sessions.pop(session_id, None)
# Publish copilot completion notification via WebSocket
if meta:
parsed = _parse_session_meta(meta, session_id)
if parsed.user_id:
try:
await _notification_bus.publish(
NotificationEvent(
user_id=parsed.user_id,
payload=CopilotCompletionPayload(
type="copilot_completion",
event="session_completed",
session_id=session_id,
status=status,
),
)
)
except Exception as e:
logger.warning(
f"Failed to publish copilot completion notification "
f"for session {session_id}: {e}"
)
return True

View File

@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -20,10 +19,7 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
@@ -36,7 +32,6 @@ from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
@@ -69,13 +64,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"continue_run_block": ContinueRunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
@@ -88,9 +80,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),

View File

@@ -32,9 +32,8 @@ import shutil
import tempfile
from typing import Any
from backend.copilot.context import get_workspace_manager
from backend.copilot.model import ChatSession
from backend.util.request import validate_url_host
from backend.util.request import validate_url
from .base import BaseTool
from .models import (
@@ -44,6 +43,7 @@ from .models import (
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
@@ -194,7 +194,7 @@ async def _save_browser_state(
),
}
manager = await get_workspace_manager(user_id, session.session_id)
manager = await get_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
@@ -218,7 +218,7 @@ async def _restore_browser_state(
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_workspace_manager(user_id, session.session_id)
manager = await get_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
@@ -235,7 +235,7 @@ async def _restore_browser_state(
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url_host(url)
await validate_url(url, trusted_origins=[])
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_workspace_manager(user_id, session_name)
manager = await get_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
)
try:
await validate_url_host(url)
await validate_url(url, trusted_origins=[])
except ValueError as e:
return ErrorResponse(
message=str(e),

View File

@@ -68,18 +68,17 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
# ---------------------------------------------------------------------------
# SSRF protection via shared validate_url_host (backend.util.request)
# SSRF protection via shared validate_url (backend.util.request)
# ---------------------------------------------------------------------------
# Patch target: validate_url_host is imported directly into agent_browser's
# module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
# Patch target: validate_url is imported directly into agent_browser's module scope.
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
class TestSsrfViaValidateUrl:
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
"""Verify that browser_navigate uses validate_url for SSRF protection.
We mock validate_url_host itself (not the low-level socket) so these tests
We mock validate_url itself (not the low-level socket) so these tests
exercise the integration point, not the internals of request.py
(which has its own thorough test suite in request_test.py).
"""
@@ -90,7 +89,7 @@ class TestSsrfViaValidateUrl:
@pytest.mark.asyncio
async def test_blocked_ip_returns_blocked_url_error(self):
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.side_effect = ValueError(
"Access to blocked IP 10.0.0.1 is not allowed."
@@ -125,8 +124,8 @@ class TestSsrfViaValidateUrl:
assert result.error == "blocked_url"
@pytest.mark.asyncio
async def test_validate_url_host_called_without_trusted_hostnames(self):
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
async def test_validate_url_called_with_empty_trusted_origins(self):
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
mock_validate.return_value = (object(), False, ["1.2.3.4"])
with patch(
@@ -144,7 +143,7 @@ class TestSsrfViaValidateUrl:
session=self.session,
url="https://example.com",
)
mock_validate.assert_called_once_with("https://example.com")
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
# ---------------------------------------------------------------------------
@@ -897,7 +896,7 @@ class TestHasLocalSession:
# _save_browser_state
# ---------------------------------------------------------------------------
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
def _make_mock_manager():

View File

@@ -1,15 +1,20 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
@@ -22,20 +27,25 @@ from .core import (
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .validation import AgentFixer, AgentValidator
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
"AgentFixer",
"AgentValidator",
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
@@ -44,6 +54,7 @@ __all__ = [
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",

View File

@@ -1,66 +0,0 @@
"""Block management for agent generation.
Provides cached access to block metadata for validation and fixing.
"""
import logging
from typing import Any, Type
from backend.blocks import get_blocks as get_block_classes
from backend.blocks._base import Block
logger = logging.getLogger(__name__)
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
# ---------------------------------------------------------------------------
# Module-level caches
# ---------------------------------------------------------------------------
_blocks_cache: list[dict[str, Any]] | None = None
def reset_block_caches() -> None:
"""Reset all module-level caches (useful after updating block descriptions)."""
global _blocks_cache
_blocks_cache = None
# ---------------------------------------------------------------------------
# 1. get_blocks_as_dicts
# ---------------------------------------------------------------------------
def get_blocks_as_dicts() -> list[dict[str, Any]]:
"""Get all available blocks as dicts (cached after first call).
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
id, name, description, inputSchema, outputSchema, categories,
staticOutput, costs, contributors, uiType.
Returns:
List of block info dicts.
"""
global _blocks_cache
if _blocks_cache is not None:
return _blocks_cache
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
blocks: list[dict[str, Any]] = []
for block_cls in block_classes.values():
try:
instance = block_cls()
info = instance.get_info().model_dump()
# Use optimized description if available (loaded at startup)
if instance.optimized_description:
info["description"] = instance.optimized_description
blocks.append(info)
except Exception:
logger.warning(
"Failed to load block info for %s, skipping",
getattr(block_cls, "__name__", "unknown"),
exc_info=True,
)
_blocks_cache = blocks
logger.info("Cached %d block dicts", len(blocks))
return _blocks_cache

View File

@@ -10,7 +10,13 @@ from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .helpers import UUID_RE_STR
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
logger = logging.getLogger(__name__)
@@ -72,7 +78,38 @@ class DecompositionResult(TypedDict, total=False):
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def extract_uuids_from_text(text: str) -> list[str]:
@@ -516,6 +553,69 @@ async def enrich_library_agents_from_steps(
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
@@ -692,3 +792,70 @@ async def get_agent_as_json(
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -0,0 +1,165 @@
"""Dummy Agent Generator for testing.
Returns mock responses matching the format expected from the external service.
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from typing import Any
logger = logging.getLogger(__name__)
# Dummy decomposition result (instructions type)
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
"type": "instructions",
"steps": [
{
"description": "Get input from user",
"action": "input",
"block_name": "AgentInputBlock",
},
{
"description": "Process the input",
"action": "process",
"block_name": "TextFormatterBlock",
},
{
"description": "Return output to user",
"action": "output",
"block_name": "AgentOutputBlock",
},
],
}
# Block IDs from backend/blocks/io.py
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def _generate_dummy_agent_json() -> dict[str, Any]:
"""Generate a minimal valid agent JSON for testing."""
input_node_id = str(uuid.uuid4())
output_node_id = str(uuid.uuid4())
return {
"id": str(uuid.uuid4()),
"version": 1,
"is_active": True,
"name": "Dummy Test Agent",
"description": "A dummy agent generated for testing purposes",
"nodes": [
{
"id": input_node_id,
"block_id": AGENT_INPUT_BLOCK_ID,
"input_default": {
"name": "input",
"title": "Input",
"description": "Enter your input",
"placeholder_values": [],
},
"metadata": {"position": {"x": 0, "y": 0}},
},
{
"id": output_node_id,
"block_id": AGENT_OUTPUT_BLOCK_ID,
"input_default": {
"name": "output",
"title": "Output",
"description": "Agent output",
"format": "{output}",
},
"metadata": {"position": {"x": 400, "y": 0}},
},
],
"links": [
{
"id": str(uuid.uuid4()),
"source_id": input_node_id,
"sink_id": output_node_id,
"source_name": "result",
"sink_name": "value",
"is_static": False,
},
],
}
async def decompose_goal_dummy(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Return dummy decomposition result."""
logger.info("Using dummy agent generator for decompose_goal")
return DUMMY_DECOMPOSITION_RESULT.copy()
async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
async def generate_agent_patch_dummy(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"
)
return patched
async def customize_template_dummy(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any]:
"""Return dummy customized template (returns template with updated description)."""
logger.info("Using dummy agent generator for customize_template")
customized = template_agent.copy()
customized["description"] = (
f"{template_agent.get('description', '')} (customized: {modification_request})"
)
return customized
async def get_blocks_dummy() -> list[dict[str, Any]]:
"""Return dummy blocks list."""
logger.info("Using dummy agent generator for get_blocks")
return [
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
]
async def health_check_dummy() -> bool:
"""Always returns healthy for dummy service."""
return True

View File

@@ -1,67 +0,0 @@
"""Shared helpers for agent generation."""
import re
import uuid
from typing import Any
from .blocks import get_blocks_as_dicts
__all__ = [
"AGENT_EXECUTOR_BLOCK_ID",
"AGENT_INPUT_BLOCK_ID",
"AGENT_OUTPUT_BLOCK_ID",
"AgentDict",
"MCP_TOOL_BLOCK_ID",
"UUID_REGEX",
"are_types_compatible",
"generate_uuid",
"get_blocks_as_dicts",
"get_defined_property_type",
"is_uuid",
]
# Type alias for the agent JSON structure passed through
# the validation and fixing pipeline.
AgentDict = dict[str, Any]
# Shared base pattern (unanchored, lowercase hex); used for both full-string
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
return isinstance(value, str) and UUID_REGEX.match(value) is not None
def generate_uuid() -> str:
"""Generate a new UUID string."""
return str(uuid.uuid4())
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
"""Get property type from a schema, handling nested `_#_` notation."""
if "_#_" in name:
parent, child = name.split("_#_", 1)
parent_schema = schema.get(parent, {})
if "properties" in parent_schema and isinstance(
parent_schema["properties"], dict
):
return parent_schema["properties"].get(child, {}).get("type")
return None
return schema.get(name, {}).get("type")
def are_types_compatible(src: str, sink: str) -> bool:
"""Check if two schema types are compatible."""
if {src, sink} <= {"integer", "number"}:
return True
return src == sink

View File

@@ -1,196 +0,0 @@
"""Shared fix → validate → preview/save pipeline for agent tools."""
import json
import logging
from typing import Any, cast
from backend.copilot.tools.models import (
AgentPreviewResponse,
AgentSavedResponse,
ErrorResponse,
ToolResponseBase,
)
from .blocks import get_blocks_as_dicts
from .core import get_library_agents_by_ids, save_agent_to_library
from .fixer import AgentFixer
from .validator import AgentValidator
logger = logging.getLogger(__name__)
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
async def fetch_library_agents(
user_id: str | None,
library_agent_ids: list[str],
) -> list[dict[str, Any]] | None:
"""Fetch library agents by IDs for AgentExecutorBlock validation.
Returns None if no IDs provided or user is not authenticated.
"""
if not user_id or not library_agent_ids:
return None
try:
agents = await get_library_agents_by_ids(
user_id=user_id,
agent_ids=library_agent_ids,
)
return cast(list[dict[str, Any]], agents)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
return None
async def fix_validate_and_save(
agent_json: dict[str, Any],
*,
user_id: str | None,
session_id: str | None,
save: bool = True,
is_update: bool = False,
default_name: str = "Agent",
preview_message: str | None = None,
save_message: str | None = None,
library_agents: list[dict[str, Any]] | None = None,
folder_id: str | None = None,
) -> ToolResponseBase:
"""Shared pipeline: auto-fix → validate → preview or save.
Args:
agent_json: The agent JSON dict (must already have id/version/is_active set).
user_id: The authenticated user's ID.
session_id: The chat session ID.
save: Whether to save or just preview.
is_update: Whether this is an update to an existing agent.
default_name: Fallback name if agent_json has none.
preview_message: Custom preview message (optional).
save_message: Custom save success message (optional).
library_agents: Library agents for AgentExecutorBlock validation/fixing.
Returns:
An appropriate ToolResponseBase subclass.
"""
# Size guard
json_size = len(json.dumps(agent_json))
if json_size > MAX_AGENT_JSON_SIZE:
return ErrorResponse(
message=(
f"Agent JSON is too large ({json_size:,} bytes, "
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
),
error="agent_json_too_large",
session_id=session_id,
)
blocks = get_blocks_as_dicts()
# Auto-fix
try:
fixer = AgentFixer()
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
fixes = fixer.get_fixes_applied()
if fixes:
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
except Exception as e:
logger.warning(f"Auto-fix failed: {e}")
# Validate
try:
validator = AgentValidator()
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
if not is_valid:
errors = validator.errors
return ErrorResponse(
message=(
f"The agent has {len(errors)} validation error(s):\n"
+ "\n".join(f"- {e}" for e in errors[:5])
),
error="validation_failed",
details={"errors": errors},
session_id=session_id,
)
except Exception as e:
logger.error(f"Validation failed with exception: {e}", exc_info=True)
return ErrorResponse(
message="Failed to validate the agent. Please try again.",
error="validation_exception",
details={"exception": str(e)},
session_id=session_id,
)
agent_name = agent_json.get("name", default_name)
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Build a warning suffix when name/description is missing or generic
_GENERIC_NAMES = {
"agent",
"generated agent",
"customized agent",
"updated agent",
"new agent",
"my agent",
}
metadata_warnings: list[str] = []
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
metadata_warnings.append("'name'")
if not agent_description:
metadata_warnings.append("'description'")
metadata_hint = ""
if metadata_warnings:
missing = " and ".join(metadata_warnings)
metadata_hint = (
f" Note: the agent is missing a meaningful {missing}. "
f"Please update the agent_json to include them."
)
if not save:
return AgentPreviewResponse(
message=(
(
preview_message
or f"Agent '{agent_name}' with {node_count} blocks is ready."
)
+ metadata_hint
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update, folder_id=folder_id
)
return AgentSavedResponse(
message=(
(save_message or f"Agent '{created_graph.name}' has been saved!")
+ metadata_hint
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to save agent: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -0,0 +1,511 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
from .dummy import (
customize_template_dummy,
decompose_goal_dummy,
generate_agent_dummy,
generate_agent_patch_dummy,
get_blocks_dummy,
health_check_dummy,
)
logger = logging.getLogger(__name__)
_dummy_mode_warned = False
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def _is_dummy_mode() -> bool:
"""Check if dummy mode is enabled for testing."""
global _dummy_mode_warned
settings = _get_settings()
is_dummy = bool(settings.config.agentgenerator_use_dummy)
if is_dummy and not _dummy_mode_warned:
logger.warning(
"Agent Generator running in DUMMY MODE - returning mock responses. "
"Do not use in production!"
)
_dummy_mode_warned = True
return is_dummy
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured (or dummy mode)."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host) or bool(
settings.config.agentgenerator_use_dummy
)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
if _is_dummy_mode():
return await decompose_goal_dummy(description, context, library_agents)
client = _get_client()
if context:
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/generate-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
)
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/update-agent", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
if _is_dummy_mode():
return await customize_template_dummy(
template_agent, modification_request, context
)
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
if _is_dummy_mode():
return await get_blocks_dummy()
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
if _is_dummy_mode():
return await health_check_dummy()
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -1,17 +0,0 @@
"""Agent generation validation — re-exports from split modules.
This module was split into:
- helpers.py: get_blocks_as_dicts, block cache
- fixer.py: AgentFixer class
- validator.py: AgentValidator class
"""
from .fixer import AgentFixer
from .helpers import get_blocks_as_dicts
from .validator import AgentValidator
__all__ = [
"AgentFixer",
"AgentValidator",
"get_blocks_as_dicts",
]

Some files were not shown because too many files have changed in this diff Show More