mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
92 Commits
autogpt-pl
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8892bcd230 | ||
|
|
48ff8300a4 | ||
|
|
c268fc6464 | ||
|
|
aff3fb44af | ||
|
|
9a41312769 | ||
|
|
048fb06b0a | ||
|
|
3f653e6614 | ||
|
|
c9c3d54b2b | ||
|
|
53d58e21d3 | ||
|
|
fa04fb41d8 | ||
|
|
d9c16ded65 | ||
|
|
6dc8429ae7 | ||
|
|
cfe22e5a8f | ||
|
|
0b594a219c | ||
|
|
a8259ca935 | ||
|
|
1f1288d623 | ||
|
|
02645732b8 | ||
|
|
ba301a3912 | ||
|
|
0cd9c0d87a | ||
|
|
a083493aa2 | ||
|
|
c51dc7ad99 | ||
|
|
bc6b82218a | ||
|
|
83e49f71cd | ||
|
|
ef446e4fe9 | ||
|
|
7b1e8ed786 | ||
|
|
7ccfff1040 | ||
|
|
81c7685a82 | ||
|
|
3595c6e769 | ||
|
|
1c2953d61b | ||
|
|
755bc84b1a | ||
|
|
ade2baa58f | ||
|
|
4d35534a89 | ||
|
|
2cc748f34c | ||
|
|
c2e79fa5e1 | ||
|
|
89a5b3178a | ||
|
|
c62d9a24ff | ||
|
|
0e0bfaac29 | ||
|
|
0633475915 | ||
|
|
34a2f9a0a2 | ||
|
|
9f4caa7dfc | ||
|
|
0876d22e22 | ||
|
|
15e3980d65 | ||
|
|
fe9eb2564b | ||
|
|
5641cdd3ca | ||
|
|
bfb843a56e | ||
|
|
684845d946 | ||
|
|
6a6b23c2e1 | ||
|
|
d0a1d72e8a | ||
|
|
f1945d6a2f | ||
|
|
6491cb1e23 | ||
|
|
c7124a5240 | ||
|
|
5537cb2858 | ||
|
|
aef5f6d666 | ||
|
|
8063391d0a | ||
|
|
0bbb12d688 | ||
|
|
eadc68f2a5 | ||
|
|
19d775c435 | ||
|
|
eca7b5e793 | ||
|
|
c304a4937a | ||
|
|
7ead4c040f | ||
|
|
8cfabcf4fd | ||
|
|
7bf407b66c | ||
|
|
0f813f1bf9 | ||
|
|
aa08063939 | ||
|
|
bde6a4c0df | ||
|
|
d56452898a | ||
|
|
7507240177 | ||
|
|
d7c3f5b8fc | ||
|
|
3e108a813a | ||
|
|
08c49a78f8 | ||
|
|
5d56548e6b | ||
|
|
6ecf55d214 | ||
|
|
7c8c7bf395 | ||
|
|
0b9e0665dd | ||
|
|
be18436e8f | ||
|
|
f6f268a1f0 | ||
|
|
ea0333c1fc | ||
|
|
21c705af6e | ||
|
|
a576be9db2 | ||
|
|
5e90585f10 | ||
|
|
3e22a0e786 | ||
|
|
6abe39b33a | ||
|
|
476cf1c601 | ||
|
|
25022f2d1e | ||
|
|
ce1675cfc7 | ||
|
|
3d0ede9f34 | ||
|
|
5474f7c495 | ||
|
|
f1b771b7ee | ||
|
|
aa7a2f0a48 | ||
|
|
3722d05b9b | ||
|
|
592830ce9b | ||
|
|
6cc680f71c |
79
.claude/skills/pr-address/SKILL.md
Normal file
79
.claude/skills/pr-address/SKILL.md
Normal file
@@ -0,0 +1,79 @@
|
||||
---
|
||||
name: pr-address
|
||||
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Address
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Fetch comments (all sources)
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||
```
|
||||
|
||||
**Bots to watch for:**
|
||||
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||
|
||||
## For each unaddressed comment
|
||||
|
||||
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||
|
||||
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||
2. Commit and push the fix
|
||||
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||
|
||||
| Comment type | How to reply |
|
||||
|---|---|
|
||||
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||
|
||||
## Format and commit
|
||||
|
||||
After fixing, format the changed code:
|
||||
|
||||
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||
|
||||
If API routes changed, regenerate the frontend client:
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||
```
|
||||
Never manually edit files in `src/app/api/__generated__/`.
|
||||
|
||||
Then commit and **push immediately** — never batch commits without pushing.
|
||||
|
||||
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||
|
||||
## The loop
|
||||
|
||||
```text
|
||||
address comments → format → commit → push
|
||||
→ re-check comments → fix new ones → push
|
||||
→ wait for CI → re-check comments after CI settles
|
||||
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||
```
|
||||
|
||||
While CI runs, stay productive: run local tests, address remaining comments.
|
||||
|
||||
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||
74
.claude/skills/pr-review/SKILL.md
Normal file
74
.claude/skills/pr-review/SKILL.md
Normal file
@@ -0,0 +1,74 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||
user-invocable: true
|
||||
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review
|
||||
|
||||
## Find the PR
|
||||
|
||||
```bash
|
||||
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||
gh pr view {N}
|
||||
```
|
||||
|
||||
## Read the diff
|
||||
|
||||
```bash
|
||||
gh pr diff {N}
|
||||
```
|
||||
|
||||
## Fetch existing review comments
|
||||
|
||||
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||
|
||||
```bash
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||
```
|
||||
|
||||
## What to check
|
||||
|
||||
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||
|
||||
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||
|
||||
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||
|
||||
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||
|
||||
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||
|
||||
## Output format
|
||||
|
||||
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||
|
||||
| Tier | Badge | Meaning |
|
||||
|---|---|---|
|
||||
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||
| Nit | `🔵 **Nit**` | Style / wording |
|
||||
|
||||
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||
|
||||
## Post inline comments
|
||||
|
||||
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||
|
||||
```bash
|
||||
# Get the latest commit SHA for the PR
|
||||
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||
|
||||
# Post an inline comment on a specific file/line
|
||||
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||
-f commit_id="$COMMIT_SHA" \
|
||||
-f path="<file path>" \
|
||||
-F line=<line number>
|
||||
```
|
||||
85
.claude/skills/worktree/SKILL.md
Normal file
85
.claude/skills/worktree/SKILL.md
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: worktree
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "3.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Create the worktree
|
||||
|
||||
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
PARENT=$(dirname "$ROOT")
|
||||
|
||||
# From an existing branch
|
||||
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||
|
||||
# From a new branch off dev
|
||||
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||
```
|
||||
|
||||
## Copy environment files
|
||||
|
||||
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||
|
||||
```bash
|
||||
ROOT=$(git rev-parse --show-toplevel)
|
||||
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||
|
||||
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||
fi
|
||||
done
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||
```
|
||||
|
||||
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||
|
||||
## Running the app (optional)
|
||||
|
||||
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||
|
||||
```bash
|
||||
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot testing
|
||||
|
||||
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
|
||||
## Cleanup
|
||||
|
||||
```bash
|
||||
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||
```
|
||||
|
||||
## Alternative: Branchlet (optional)
|
||||
|
||||
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||
|
||||
```bash
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
```
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,12 +5,14 @@ on:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
merge_group:
|
||||
|
||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,175 +120,6 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
312
.github/workflows/platform-fullstack-ci.yml
vendored
312
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,14 +1,18 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
name: AutoGPT Platform - Full-stack CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
@@ -24,42 +28,28 @@ defaults:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies to populate cache
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: big-boi
|
||||
check-api-types:
|
||||
name: check API types
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -67,70 +57,256 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
# ------------------------ Backend setup ------------------------
|
||||
|
||||
- name: Set up Backend - Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Backend - Install Poetry
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
|
||||
- name: Set up Backend - Set up dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Set up Backend - Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Set up Backend - Generate Prisma client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
|
||||
# ------------------------ Frontend setup ------------------------
|
||||
|
||||
- name: Set up Frontend - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Frontend - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Set up Frontend - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
- name: Set up Frontend - Format OpenAPI schema
|
||||
id: format-schema
|
||||
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "The openapi.json file has been modified after exporting the API schema."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo "\nIn the backend directory:"
|
||||
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||
echo "\nIn the frontend directory:"
|
||||
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||
echo "3. Run 'pnpm generate:api'"
|
||||
echo "4. Run 'pnpm types'"
|
||||
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||
echo "6. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
- name: Set up Frontend - Generate API client
|
||||
id: generate-api-client
|
||||
run: pnpm orval --config ./orval.config.ts
|
||||
# Continue with type generation & check even if there are schema changes
|
||||
if: success() || (steps.format-schema.outcome == 'success')
|
||||
|
||||
- name: Check for TypeScript errors
|
||||
run: pnpm types
|
||||
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
@@ -60,9 +60,12 @@ AutoGPT Platform is a monorepo containing:
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||
|
||||
When fetching comments manually:
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
|
||||
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
@@ -0,0 +1,40 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.auth_activities
|
||||
-- Looker source alias: ds49 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Tracks authentication events (login, logout, SSO, password
|
||||
-- reset, etc.) from Supabase's internal audit log.
|
||||
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.audit_log_entries — Supabase internal auth event log
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||
-- actor_id TEXT User ID who triggered the event
|
||||
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days from current date
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily login counts
|
||||
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||
-- FROM analytics.auth_activities
|
||||
-- WHERE action = 'login'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- SSO vs password login breakdown
|
||||
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||
-- WHERE action = 'login' GROUP BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
created_at,
|
||||
payload->>'actor_id' AS actor_id,
|
||||
payload->>'actor_via_sso' AS actor_via_sso,
|
||||
payload->>'action' AS action
|
||||
FROM auth.audit_log_entries
|
||||
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
@@ -0,0 +1,105 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.graph_execution
|
||||
-- Looker source alias: ds16 | Charts: 21
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per agent graph execution (last 90 days).
|
||||
-- Unpacks the JSONB stats column into individual numeric columns
|
||||
-- and normalises the executionStatus — runs that failed due to
|
||||
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||
-- to allow safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Execution UUID
|
||||
-- agentGraphId TEXT Agent graph UUID
|
||||
-- agentGraphVersion INT Graph version number
|
||||
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||
-- updatedAt TIMESTAMPTZ Last status update time
|
||||
-- userId TEXT Owner user UUID
|
||||
-- agentGraphName TEXT Human-readable agent name
|
||||
-- cputime DECIMAL Total CPU seconds consumed
|
||||
-- walltime DECIMAL Total wall-clock seconds
|
||||
-- node_count DECIMAL Number of nodes in the graph
|
||||
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||
-- execution_cost DECIMAL Credit cost of this execution
|
||||
-- correctness_score FLOAT AI correctness score (if available)
|
||||
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily execution counts by status
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||
-- FROM analytics.graph_execution
|
||||
-- GROUP BY 1, 2 ORDER BY 1;
|
||||
--
|
||||
-- -- Average cost per execution by agent
|
||||
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'COMPLETED'
|
||||
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||
--
|
||||
-- -- Top error messages
|
||||
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ge."id" AS id,
|
||||
ge."agentGraphId" AS agentGraphId,
|
||||
ge."agentGraphVersion" AS agentGraphVersion,
|
||||
CASE
|
||||
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||
AND (
|
||||
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||
)
|
||||
THEN 'NO_CREDITS'
|
||||
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||
END AS executionStatus,
|
||||
ge."createdAt" AS createdAt,
|
||||
ge."updatedAt" AS updatedAt,
|
||||
ge."userId" AS userId,
|
||||
g."name" AS agentGraphName,
|
||||
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage
|
||||
FROM platform."AgentGraphExecution" ge
|
||||
LEFT JOIN platform."AgentGraph" g
|
||||
ON ge."agentGraphId" = g."id"
|
||||
AND ge."agentGraphVersion" = g."version"
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||
"userId", "agentGraphId",
|
||||
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||
FROM platform."LibraryAgent"
|
||||
WHERE "isDeleted" = FALSE
|
||||
AND "isArchived" = FALSE
|
||||
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
@@ -0,0 +1,101 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.node_block_execution
|
||||
-- Looker source alias: ds14 | Charts: 11
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per node (block) execution (last 90 days).
|
||||
-- Unpacks stats JSONB and joins to identify which block type
|
||||
-- was run. For failed nodes, joins the error output and
|
||||
-- scrubs it for safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentNodeExecution — Node execution records
|
||||
-- platform.AgentNode — Node → block mapping
|
||||
-- platform.AgentBlock — Block name/ID
|
||||
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Node execution UUID
|
||||
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||
-- agentNodeId TEXT Node UUID within the graph
|
||||
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||
-- addedTime TIMESTAMPTZ When the node was queued
|
||||
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||
-- startedTime TIMESTAMPTZ When execution started
|
||||
-- endedTime TIMESTAMPTZ When execution finished
|
||||
-- inputSize BIGINT Input payload size in bytes
|
||||
-- outputSize BIGINT Output payload size in bytes
|
||||
-- walltime NUMERIC Wall-clock seconds for this node
|
||||
-- cputime NUMERIC CPU seconds for this node
|
||||
-- llmRetryCount INT Number of LLM retries
|
||||
-- llmCallCount INT Number of LLM API calls made
|
||||
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||
-- outputTokenCount BIGINT LLM output tokens produced
|
||||
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||
-- blockId TEXT Block UUID
|
||||
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Most-used blocks by execution count
|
||||
-- SELECT "blockName", COUNT(*) AS executions,
|
||||
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||
-- FROM analytics.node_block_execution
|
||||
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||
--
|
||||
-- -- Average LLM token usage per block
|
||||
-- SELECT "blockName",
|
||||
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "llmCallCount" > 0
|
||||
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||
--
|
||||
-- -- Top failure reasons
|
||||
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ne."id" AS id,
|
||||
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||
ne."agentNodeId" AS agentNodeId,
|
||||
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||
ne."addedTime" AS addedTime,
|
||||
ne."queuedTime" AS queuedTime,
|
||||
ne."startedTime" AS startedTime,
|
||||
ne."endedTime" AS endedTime,
|
||||
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||
b."name" AS blockName,
|
||||
b."id" AS blockId,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM eio."data"::text),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage,
|
||||
eio."data" AS errorMessage
|
||||
FROM platform."AgentNodeExecution" ne
|
||||
LEFT JOIN platform."AgentNode" nd
|
||||
ON ne."agentNodeId" = nd."id"
|
||||
LEFT JOIN platform."AgentBlock" b
|
||||
ON nd."agentBlockId" = b."id"
|
||||
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||
ON eio."referencedByOutputExecId" = ne."id"
|
||||
AND eio."name" = 'error'
|
||||
AND ne."executionStatus" = 'FAILED'
|
||||
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
@@ -0,0 +1,97 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_agent
|
||||
-- Looker source alias: ds35 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention broken down per individual agent.
|
||||
-- Cohort = week of a user's first use of THAT specific agent.
|
||||
-- Tells you which agents keep users coming back vs. one-shot
|
||||
-- use. Only includes cohorts from the last 180 days.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||
-- platform.AgentGraph — Agent names
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- agent_id TEXT Agent graph UUID
|
||||
-- agent_label TEXT 'AgentName [first8chars]'
|
||||
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||
-- cohort_week_start DATE Week users first ran this agent
|
||||
-- cohort_label TEXT ISO week label
|
||||
-- cohort_label_n TEXT ISO week label with cohort size
|
||||
-- user_lifetime_week INT Weeks since first use of this agent
|
||||
-- cohort_users BIGINT Users in this cohort for this agent
|
||||
-- active_users BIGINT Users who ran the agent again in week k
|
||||
-- retention_rate FLOAT active_users / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Best-retained agents at week 2
|
||||
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||
-- FROM analytics.retention_agent
|
||||
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||
--
|
||||
-- -- Agents with most unique users
|
||||
-- SELECT DISTINCT agent_label, agent_total_users
|
||||
-- FROM analytics.retention_agent
|
||||
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||
e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e
|
||||
),
|
||||
first_use AS (
|
||||
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1,2
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||
),
|
||||
active_counts AS (
|
||||
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||
),
|
||||
cohort_sizes AS (
|
||||
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||
),
|
||||
cohort_caps AS (
|
||||
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
),
|
||||
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||
SELECT
|
||||
g.agent_id,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(ac.active_users,0) AS active_users,
|
||||
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||
FROM grid g
|
||||
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_daily
|
||||
-- Looker source alias: ds111 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on agent executions.
|
||||
-- Cohort anchor = day of user's FIRST ever execution.
|
||||
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||
-- Great for early engagement analysis (did users run another
|
||||
-- agent the next day?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_daily.
|
||||
-- cohort_day_start = day of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-3 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||
-- FROM analytics.retention_execution_daily
|
||||
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fe.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,81 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_weekly
|
||||
-- Looker source alias: ds92 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on agent executions.
|
||||
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||
-- (not first login). Only includes cohorts from the last 180 days.
|
||||
-- Useful when you care about product engagement, not just visits.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_weekly.
|
||||
-- cohort_week_start = week of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-2 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded
|
||||
-- FROM analytics.retention_execution_weekly
|
||||
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fe.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
@@ -0,0 +1,94 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_daily
|
||||
-- Looker source alias: ds112 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on login sessions.
|
||||
-- Same logic as retention_login_weekly but at day granularity,
|
||||
-- showing up to day 30 for cohorts from the last 90 days.
|
||||
-- Useful for analysing early activation (days 1-7) in detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||
-- cohort_day_start DATE First day the cohort logged in
|
||||
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||
-- cohort_users BIGINT Total users in cohort
|
||||
-- active_users_bounded BIGINT Users active on exactly day k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-1 retention rate (came back next day)
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||
--
|
||||
-- -- Average retention curve across all cohorts
|
||||
-- SELECT user_lifetime_day,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fl.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -0,0 +1,96 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_onboarded_weekly
|
||||
-- Looker source alias: ds101 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention from login sessions, restricted to
|
||||
-- users who "onboarded" — defined as running at least one
|
||||
-- agent within 365 days of their first login.
|
||||
-- Filters out users who signed up but never activated,
|
||||
-- giving a cleaner view of engaged-user retention.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||
-- Only difference: cohort is filtered to onboarded users only.
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Compare week-4 retention: all users vs onboarded only
|
||||
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||
-- UNION ALL
|
||||
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login_all AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
onboarders AS (
|
||||
SELECT fl.user_id FROM first_login_all fl
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||
WHERE e."userId"::text = fl.user_id
|
||||
AND e."createdAt" >= fl.first_login_time
|
||||
AND e."createdAt" < fl.first_login_time
|
||||
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||
)
|
||||
),
|
||||
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
@@ -0,0 +1,103 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_weekly
|
||||
-- Looker source alias: ds83 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on login sessions.
|
||||
-- Users are grouped by the ISO week of their first ever login.
|
||||
-- For each cohort × lifetime-week combination, outputs both:
|
||||
-- - bounded rate: % active in exactly that week
|
||||
-- - unbounded rate: % who were ever active on or after that week
|
||||
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- HOW TO READ THE OUTPUT
|
||||
-- cohort_week_start The Monday of the week users first logged in
|
||||
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- cohort_week_start DATE First day of the cohort's signup week
|
||||
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||
-- active_users_bounded BIGINT Users active in exactly week k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-1 retention rate per cohort
|
||||
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- WHERE user_lifetime_week = 1
|
||||
-- ORDER BY cohort_week_start;
|
||||
--
|
||||
-- -- Overall average retention curve (all cohorts combined)
|
||||
-- SELECT user_lifetime_week,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
@@ -0,0 +1,71 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_block_spending
|
||||
-- Looker source alias: ds6 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per credit transaction (last 90 days).
|
||||
-- Shows how users spend credits broken down by block type,
|
||||
-- LLM provider and model. Joins node execution stats for
|
||||
-- token-level detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.CreditTransaction — Credit debit/credit records
|
||||
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- transactionKey TEXT Unique transaction identifier
|
||||
-- userId TEXT User who was charged
|
||||
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||
-- blockId TEXT Block UUID that triggered the spend
|
||||
-- blockName TEXT Human-readable block name
|
||||
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||
-- node_exec_id TEXT Linked node execution UUID
|
||||
-- llm_call_count INT LLM API calls made in that execution
|
||||
-- llm_retry_count INT LLM retries in that execution
|
||||
-- llm_input_token_count INT Input tokens consumed
|
||||
-- llm_output_token_count INT Output tokens produced
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend per user (last 90 days)
|
||||
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "transactionType" = 'USAGE'
|
||||
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||
--
|
||||
-- -- Spend by LLM provider + model
|
||||
-- SELECT "llm_provider", "llm_model",
|
||||
-- SUM("negativeAmount") AS total_cost,
|
||||
-- SUM("llm_input_token_count") AS input_tokens,
|
||||
-- SUM("llm_output_token_count") AS output_tokens
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "llm_provider" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
c."transactionKey" AS transactionKey,
|
||||
c."userId" AS userId,
|
||||
c."amount" AS amount,
|
||||
c."amount" * -1 AS negativeAmount,
|
||||
c."type" AS transactionType,
|
||||
c."createdAt" AS transactionTime,
|
||||
c.metadata->>'block_id' AS blockId,
|
||||
c.metadata->>'block' AS blockName,
|
||||
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||
c.metadata->'input'->>'model' AS llm_model,
|
||||
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||
FROM platform."CreditTransaction" c
|
||||
LEFT JOIN platform."AgentNodeExecution" ne
|
||||
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
@@ -0,0 +1,45 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding
|
||||
-- Looker source alias: ds68 | Charts: 3
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user onboarding record. Contains the user's
|
||||
-- stated usage reason, selected integrations, completed
|
||||
-- onboarding steps and optional first agent selection.
|
||||
-- Full history (no date filter) since onboarding happens
|
||||
-- once per user.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding state per user
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Onboarding record UUID
|
||||
-- createdAt TIMESTAMPTZ When onboarding started
|
||||
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||
-- integrations TEXT[] Array of integration names the user selected
|
||||
-- userId TEXT User UUID
|
||||
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Usage reason breakdown
|
||||
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||
--
|
||||
-- -- Completion rate per step
|
||||
-- SELECT step, COUNT(*) AS users_completed
|
||||
-- FROM analytics.user_onboarding
|
||||
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"usageReason",
|
||||
integrations,
|
||||
"userId",
|
||||
"completedSteps",
|
||||
"selectedStoreListingVersionId"
|
||||
FROM platform."UserOnboarding"
|
||||
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
@@ -0,0 +1,100 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_funnel
|
||||
-- Looker source alias: ds74 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated onboarding funnel showing how many users
|
||||
-- completed each step and the drop-off percentage from the
|
||||
-- previous step. One row per onboarding step (all 22 steps
|
||||
-- always present, even with 0 completions — prevents sparse
|
||||
-- gaps from making LAG compare the wrong predecessors).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||
-- users_completed BIGINT Distinct users who completed this step
|
||||
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||
--
|
||||
-- STEP ORDER
|
||||
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full funnel
|
||||
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||
--
|
||||
-- -- Biggest drop-off point
|
||||
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||
-- =============================================================
|
||||
|
||||
WITH all_steps AS (
|
||||
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||
-- are always present, keeping LAG comparisons correct.
|
||||
SELECT step_name, step_order
|
||||
FROM (VALUES
|
||||
('WELCOME', 1),
|
||||
('USAGE_REASON', 2),
|
||||
('INTEGRATIONS', 3),
|
||||
('AGENT_CHOICE', 4),
|
||||
('AGENT_NEW_RUN', 5),
|
||||
('AGENT_INPUT', 6),
|
||||
('CONGRATS', 7),
|
||||
('GET_RESULTS', 8),
|
||||
('MARKETPLACE_VISIT', 9),
|
||||
('MARKETPLACE_ADD_AGENT', 10),
|
||||
('MARKETPLACE_RUN_AGENT', 11),
|
||||
('BUILDER_OPEN', 12),
|
||||
('BUILDER_SAVE_AGENT', 13),
|
||||
('BUILDER_RUN_AGENT', 14),
|
||||
('VISIT_COPILOT', 15),
|
||||
('RE_RUN_AGENT', 16),
|
||||
('SCHEDULE_AGENT', 17),
|
||||
('RUN_AGENTS', 18),
|
||||
('RUN_3_DAYS', 19),
|
||||
('TRIGGER_WEBHOOK', 20),
|
||||
('RUN_14_DAYS', 21),
|
||||
('RUN_AGENTS_100', 22)
|
||||
) AS t(step_name, step_order)
|
||||
),
|
||||
raw AS (
|
||||
SELECT
|
||||
u."userId",
|
||||
step_txt::text AS step
|
||||
FROM platform."UserOnboarding" u
|
||||
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
),
|
||||
step_counts AS (
|
||||
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||
FROM raw GROUP BY step
|
||||
),
|
||||
funnel AS (
|
||||
SELECT
|
||||
a.step_name AS step,
|
||||
a.step_order,
|
||||
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||
ROUND(
|
||||
100.0 * COALESCE(sc.users_completed, 0)
|
||||
/ NULLIF(
|
||||
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||
0
|
||||
),
|
||||
2
|
||||
) AS pct_from_prev
|
||||
FROM all_steps a
|
||||
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||
)
|
||||
SELECT * FROM funnel ORDER BY step_order
|
||||
@@ -0,0 +1,41 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_integration
|
||||
-- Looker source alias: ds75 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated count of users who selected each integration
|
||||
-- during onboarding. One row per integration type, sorted
|
||||
-- by popularity.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — integrations array column
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full integration popularity ranking
|
||||
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||
--
|
||||
-- -- Top 5 integrations
|
||||
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||
-- =============================================================
|
||||
|
||||
WITH exploded AS (
|
||||
SELECT
|
||||
u."userId" AS user_id,
|
||||
UNNEST(u."integrations") AS integration
|
||||
FROM platform."UserOnboarding" u
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
)
|
||||
SELECT
|
||||
integration,
|
||||
COUNT(DISTINCT user_id) AS users_with_integration
|
||||
FROM exploded
|
||||
WHERE integration IS NOT NULL AND integration <> ''
|
||||
GROUP BY integration
|
||||
ORDER BY users_with_integration DESC
|
||||
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
@@ -0,0 +1,145 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.users_activities
|
||||
-- Looker source alias: ds56 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user with lifetime activity summary.
|
||||
-- Joins login sessions with agent graphs, executions and
|
||||
-- node-level runs to give a full picture of how engaged
|
||||
-- each user is. Includes a convenience flag for 7-day
|
||||
-- activation (did the user return at least 7 days after
|
||||
-- their first login?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login/session records
|
||||
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||
-- platform.AgentGraphExecution — Agent run history
|
||||
-- platform.AgentNodeExecution — Individual block execution history
|
||||
--
|
||||
-- PERFORMANCE NOTE
|
||||
-- Each CTE aggregates its own table independently by userId.
|
||||
-- This avoids the fan-out that occurs when driving every join
|
||||
-- from user_logins across the two largest tables
|
||||
-- (AgentGraphExecution and AgentNodeExecution).
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- user_id TEXT Supabase user UUID
|
||||
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||
-- node_execution_count BIGINT Total node executions across all runs
|
||||
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Users who ran at least one agent and returned after 7 days
|
||||
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||
--
|
||||
-- -- Top 10 most active users by agent runs
|
||||
-- SELECT user_id, agent_runs, node_execution_count
|
||||
-- FROM analytics.users_activities
|
||||
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||
--
|
||||
-- -- 7-day activation rate
|
||||
-- SELECT
|
||||
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||
-- AS activation_rate
|
||||
-- FROM analytics.users_activities;
|
||||
-- =============================================================
|
||||
|
||||
WITH user_logins AS (
|
||||
SELECT
|
||||
user_id::text AS user_id,
|
||||
MIN(created_at) AS first_login_time,
|
||||
MAX(created_at) AS last_login_time,
|
||||
GREATEST(
|
||||
MAX(refreshed_at)::timestamptz,
|
||||
MAX(created_at)::timestamptz
|
||||
) AS last_visit_time
|
||||
FROM auth.sessions
|
||||
GROUP BY user_id
|
||||
),
|
||||
user_agents AS (
|
||||
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MAX("updatedAt") AS last_agent_save_time,
|
||||
COUNT(DISTINCT "id") AS agent_count
|
||||
FROM platform."AgentGraph"
|
||||
WHERE "isActive"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_graph_runs AS (
|
||||
-- Aggregate AgentGraphExecution directly by userId
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MIN("createdAt") AS first_agent_run_time,
|
||||
MAX("createdAt") AS last_agent_run_time,
|
||||
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||
COUNT("id") AS agent_runs
|
||||
FROM platform."AgentGraphExecution"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_node_runs AS (
|
||||
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||
-- single join to AgentGraphExecution instead of fanning out from
|
||||
-- user_logins through both large tables.
|
||||
SELECT
|
||||
g."userId"::text AS user_id,
|
||||
COUNT(*) AS node_execution_count,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||
FROM platform."AgentNodeExecution" n
|
||||
JOIN platform."AgentGraphExecution" g
|
||||
ON g."id" = n."agentGraphExecutionId"
|
||||
GROUP BY g."userId"
|
||||
)
|
||||
SELECT
|
||||
ul.user_id,
|
||||
ul.first_login_time,
|
||||
ul.last_login_time,
|
||||
ul.last_visit_time,
|
||||
ua.last_agent_save_time,
|
||||
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||
gr.first_agent_run_time,
|
||||
gr.last_agent_run_time,
|
||||
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||
CASE
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||
ELSE NULL
|
||||
END AS is_active_after_7d,
|
||||
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||
FROM user_logins ul
|
||||
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||
@@ -37,6 +37,10 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
|
||||
@@ -58,10 +58,31 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||
- **List comprehensions** over manual loop-and-append
|
||||
- **Early return** — guard clauses first, avoid deep nesting
|
||||
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||
- After refactoring, update mock targets to match new module paths
|
||||
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
@@ -157,6 +178,16 @@ yield "image_url", result_url
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Workspace & Media Files
|
||||
|
||||
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||
- Working on CoPilot file upload/download features
|
||||
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||
- Debugging file persistence or virus scanning issues
|
||||
|
||||
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
@@ -111,13 +111,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,9 +9,10 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
@@ -230,13 +231,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -278,7 +279,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -306,13 +307,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -348,7 +349,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -14,3 +23,70 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
Paginated listings with their versions
|
||||
"""
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -84,31 +73,24 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmission with updated review information
|
||||
StoreSubmissionAdminView with updated review information
|
||||
"""
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -0,0 +1,168 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -11,7 +11,7 @@ from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
@@ -25,8 +25,10 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
@@ -51,6 +53,8 @@ from backend.copilot.tools.models import (
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.understanding import get_business_understanding
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -125,6 +129,7 @@ class SessionSummaryResponse(BaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
is_processing: bool
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
@@ -141,6 +146,20 @@ class CancelSessionResponse(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -169,6 +188,28 @@ async def list_sessions(
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
# Batch-check Redis for active stream status on each session
|
||||
processing_set: set[str] = set()
|
||||
if sessions:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
for session in sessions:
|
||||
pipe.hget(
|
||||
f"{config.session_meta_prefix}{session.session_id}",
|
||||
"status",
|
||||
)
|
||||
statuses = await pipe.execute()
|
||||
processing_set = {
|
||||
session.session_id
|
||||
for session, st in zip(sessions, statuses)
|
||||
if st == "running"
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
@@ -176,6 +217,7 @@ async def list_sessions(
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
is_processing=session.session_id in processing_set,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
@@ -250,12 +292,12 @@ async def delete_session(
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
config = ChatConfig()
|
||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
|
||||
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||
e2b_cfg = ChatConfig()
|
||||
if e2b_cfg.e2b_active:
|
||||
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||
try:
|
||||
await kill_sandbox(session_id, config.e2b_api_key)
|
||||
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
@@ -264,6 +306,43 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -753,7 +832,6 @@ async def resume_session_stream(
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
@@ -776,6 +854,36 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Suggested Prompts ==========
|
||||
|
||||
|
||||
class SuggestedPromptsResponse(BaseModel):
|
||||
"""Response model for user-specific suggested prompts."""
|
||||
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/suggested-prompts",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def get_suggested_prompts(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SuggestedPromptsResponse:
|
||||
"""
|
||||
Get LLM-generated suggested prompts for the authenticated user.
|
||||
|
||||
Returns personalized quick-action prompts based on the user's
|
||||
business understanding. Returns an empty list if no custom prompts
|
||||
are available.
|
||||
"""
|
||||
understanding = await get_business_understanding(user_id)
|
||||
if understanding is None:
|
||||
return SuggestedPromptsResponse(prompts=[])
|
||||
|
||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Tests for chat route file_ids validation and enrichment."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -17,6 +19,7 @@ TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
@@ -24,7 +27,95 @@ def setup_app_auth(mock_jwt_user):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---- file_ids Pydantic validation (B1) ----
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
@@ -92,7 +183,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- UUID format filtering ----
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
@@ -131,7 +222,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ---- Cross-workspace file_ids ----
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
@@ -158,3 +249,62 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "test_node_123"
|
||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
||||
|
||||
# Mock get_node_executions to return node_id mapping
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
||||
|
||||
# Mock get_node_executions to return batch node data
|
||||
mock_get_node_executions = mocker.patch(
|
||||
"backend.data.execution.get_node_executions"
|
||||
"backend.api.features.executions.review.routes.get_node_executions"
|
||||
)
|
||||
# Create mock node executions for each review
|
||||
mock_node_execs = []
|
||||
|
||||
@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.copilot.constants import (
|
||||
is_copilot_synthetic_id,
|
||||
parse_node_id_from_exec_id,
|
||||
)
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_executions,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
@@ -22,6 +27,7 @@ from backend.data.human_review import (
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -35,6 +41,38 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_node_ids(
|
||||
node_exec_ids: list[str],
|
||||
graph_exec_id: str,
|
||||
is_copilot: bool,
|
||||
) -> dict[str, str]:
|
||||
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||
|
||||
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||
Graph executions look up node_id from NodeExecution records.
|
||||
"""
|
||||
if not node_exec_ids:
|
||||
return {}
|
||||
|
||||
if is_copilot:
|
||||
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||
|
||||
result = {}
|
||||
for neid in node_exec_ids:
|
||||
if neid in node_exec_map:
|
||||
result[neid] = node_exec_map[neid]
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending",
|
||||
summary="Get Pending Reviews",
|
||||
@@ -109,14 +147,16 @@ async def list_pending_reviews_for_execution(
|
||||
"""
|
||||
|
||||
# Verify user owns the graph execution before returning reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
# (CoPilot synthetic IDs don't have graph execution records)
|
||||
if not is_copilot_synthetic_id(graph_exec_id):
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
|
||||
@@ -159,30 +199,26 @@ async def process_review_action(
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||
if not is_copilot:
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
@@ -235,7 +271,7 @@ async def process_review_action(
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
@@ -243,29 +279,16 @@ async def process_review_action(
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
# Batch-fetch node executions to get node_ids
|
||||
node_id_map = await _resolve_node_ids(
|
||||
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||
)
|
||||
|
||||
# Deduplicate by node_id — one auto-approval per node
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_id = node_id_map.get(node_exec_id)
|
||||
if node_id and node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
@@ -280,13 +303,11 @@ async def process_review_action(
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
@@ -301,30 +322,31 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
# Resume graph execution only for real graph executions (not CoPilot)
|
||||
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||
if not is_copilot and updated_reviews:
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
|
||||
@@ -8,7 +8,6 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.api.features.store.exceptions as store_exceptions
|
||||
import backend.api.features.store.image_gen as store_image_gen
|
||||
import backend.api.features.store.media as store_media
|
||||
import backend.data.graph as graph_db
|
||||
@@ -251,7 +250,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
The requested LibraryAgent.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during retrieval.
|
||||
"""
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
@@ -398,6 +397,7 @@ async def create_library_agent(
|
||||
hitl_safe_mode: bool = True,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
folder_id: str | None = None,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
@@ -414,12 +414,18 @@ async def create_library_agent(
|
||||
If the graph has sub-graphs, the parent graph will always be the first entry in the list.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the specified agent does not exist.
|
||||
NotFoundError: If the specified agent does not exist.
|
||||
DatabaseError: If there's an error during creation or if image generation fails.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; user:<redacted>"
|
||||
)
|
||||
|
||||
# Authorization: FK only checks existence, not ownership.
|
||||
# Verify the folder belongs to this user to prevent cross-user nesting.
|
||||
if folder_id:
|
||||
await get_folder(folder_id, user_id)
|
||||
|
||||
graph_entries = (
|
||||
[graph, *graph.sub_graphs] if create_library_agents_for_sub_graphs else [graph]
|
||||
)
|
||||
@@ -432,7 +438,6 @@ async def create_library_agent(
|
||||
isCreatedByUser=(user_id == user_id),
|
||||
useGraphIsActiveVersion=True,
|
||||
User={"connect": {"id": user_id}},
|
||||
# Creator={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {
|
||||
@@ -448,6 +453,11 @@ async def create_library_agent(
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
),
|
||||
**(
|
||||
{"Folder": {"connect": {"id": folder_id}}}
|
||||
if folder_id and graph_entry is graph
|
||||
else {}
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
@@ -529,6 +539,7 @@ async def update_agent_version_in_library(
|
||||
async def create_graph_in_library(
|
||||
graph: graph_db.Graph,
|
||||
user_id: str,
|
||||
folder_id: str | None = None,
|
||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||
"""Create a new graph and add it to the user's library."""
|
||||
graph.version = 1
|
||||
@@ -542,6 +553,7 @@ async def create_graph_in_library(
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
if created_graph.is_active:
|
||||
@@ -817,7 +829,7 @@ async def add_store_agent_to_library(
|
||||
The newly created LibraryAgent if successfully added, the existing corresponding one if any.
|
||||
|
||||
Raises:
|
||||
AgentNotFoundError: If the store listing or associated agent is not found.
|
||||
NotFoundError: If the store listing or associated agent is not found.
|
||||
DatabaseError: If there's an issue creating the LibraryAgent record.
|
||||
"""
|
||||
logger.debug(
|
||||
@@ -832,7 +844,7 @@ async def add_store_agent_to_library(
|
||||
)
|
||||
if not store_listing_version or not store_listing_version.AgentGraph:
|
||||
logger.warning(f"Store listing version not found: {store_listing_version_id}")
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Store listing version {store_listing_version_id} not found or invalid"
|
||||
)
|
||||
|
||||
@@ -846,7 +858,7 @@ async def add_store_agent_to_library(
|
||||
include_subgraphs=False,
|
||||
)
|
||||
if not graph_model:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph.id} v{graph.version} not found or accessible"
|
||||
)
|
||||
|
||||
@@ -1481,6 +1493,67 @@ async def bulk_move_agents_to_folder(
|
||||
return [library_model.LibraryAgent.from_db(agent) for agent in agents]
|
||||
|
||||
|
||||
def collect_tree_ids(
|
||||
nodes: list[library_model.LibraryFolderTree],
|
||||
visited: set[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Collect all folder IDs from a folder tree."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
ids: list[str] = []
|
||||
for n in nodes:
|
||||
if n.id in visited:
|
||||
continue
|
||||
visited.add(n.id)
|
||||
ids.append(n.id)
|
||||
ids.extend(collect_tree_ids(n.children, visited))
|
||||
return ids
|
||||
|
||||
|
||||
async def get_folder_agent_summaries(
|
||||
user_id: str, folder_id: str
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of agents in a folder (id, name, description)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, folder_id=folder_id, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_root_agent_summaries(
|
||||
user_id: str,
|
||||
) -> list[dict[str, str | None]]:
|
||||
"""Get a lightweight list of root-level agents (folderId IS NULL)."""
|
||||
all_agents: list[library_model.LibraryAgent] = []
|
||||
for page in itertools.count(1):
|
||||
resp = await list_library_agents(
|
||||
user_id=user_id, include_root_only=True, page=page
|
||||
)
|
||||
all_agents.extend(resp.agents)
|
||||
if page >= resp.pagination.total_pages:
|
||||
break
|
||||
return [
|
||||
{"id": a.id, "name": a.name, "description": a.description} for a in all_agents
|
||||
]
|
||||
|
||||
|
||||
async def get_folder_agents_map(
|
||||
user_id: str, folder_ids: list[str]
|
||||
) -> dict[str, list[dict[str, str | None]]]:
|
||||
"""Get agent summaries for multiple folders concurrently."""
|
||||
results = await asyncio.gather(
|
||||
*(get_folder_agent_summaries(user_id, fid) for fid in folder_ids)
|
||||
)
|
||||
return dict(zip(folder_ids, results))
|
||||
|
||||
|
||||
##############################################
|
||||
########### Presets DB Functions #############
|
||||
##############################################
|
||||
|
||||
@@ -4,7 +4,6 @@ import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
with pytest.raises(db.NotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
|
||||
@@ -165,7 +165,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -206,7 +205,9 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
can_access_graph: bool
|
||||
can_access_graph: bool = pydantic.Field(
|
||||
description="Indicates whether the same user owns the corresponding graph"
|
||||
)
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
@@ -324,7 +325,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -67,7 +66,6 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -131,7 +129,6 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -184,7 +181,6 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
await validate_url_host(auth_server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
await validate_url_host(registration_endpoint)
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
await validate_url_host(request.server_url)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ async def client():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
@@ -521,12 +521,12 @@ class TestStoreToken:
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
@@ -541,7 +541,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
@@ -556,7 +556,7 @@ class TestSSRFValidation:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
"backend.api.features.mcp.routes.validate_url_host",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
@@ -23,7 +21,7 @@ def clear_all_caches():
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
return await store_db.get_store_creator(username=username.lower())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_agents = [
|
||||
prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video=None,
|
||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
# Mock data - StoreAgent view already contains the active version data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="version123",
|
||||
listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
use_for_onboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
# Mock StoreAgent prisma call
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - should use active version data
|
||||
# Verify results - constructed from the StoreAgent view
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "version123"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
assert result.store_listing_version_id == "version123"
|
||||
assert result.graph_id == "test-graph-id"
|
||||
assert result.runs == 10
|
||||
assert result.rating == 4.5
|
||||
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
# Verify single StoreAgent lookup
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
async def test_get_store_creator(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
name="Test Creator",
|
||||
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_creator_details("creator")
|
||||
result = await db.get_store_creator("creator")
|
||||
|
||||
# Verify results
|
||||
assert result.username == "creator"
|
||||
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
now = datetime.now()
|
||||
|
||||
# Mock agent graph (with no pending submissions) and user with profile
|
||||
mock_profile = prisma.models.Profile(
|
||||
id="profile-id",
|
||||
userId="user-id",
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test",
|
||||
isFeatured=False,
|
||||
links=[],
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
mock_user = prisma.models.User(
|
||||
id="user-id",
|
||||
email="test@example.com",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
Profile=[mock_profile],
|
||||
emailVerified=True,
|
||||
metadata="{}", # type: ignore[reportArgumentType]
|
||||
integrations="",
|
||||
maxEmailsPerDay=1,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
)
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
version=1,
|
||||
userId="user-id",
|
||||
createdAt=datetime.now(),
|
||||
createdAt=now,
|
||||
isActive=True,
|
||||
StoreListingVersions=[],
|
||||
User=mock_user,
|
||||
)
|
||||
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
# Mock the created StoreListingVersion (returned by create)
|
||||
mock_store_listing_obj = prisma.models.StoreListing(
|
||||
id="listing-id",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
subHeading="Test heading",
|
||||
imageUrls=["image.jpg"],
|
||||
categories=["test"],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
mock_version = prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
subHeading="",
|
||||
imageUrls=[],
|
||||
categories=[],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
submittedAt=now,
|
||||
StoreListing=mock_store_listing_obj,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
user_id="user-id",
|
||||
agent_id="agent-id",
|
||||
agent_version=1,
|
||||
graph_id="agent-id",
|
||||
graph_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
|
||||
# Verify results
|
||||
assert result.name == "Test Agent"
|
||||
assert result.description == "Test description"
|
||||
assert result.store_listing_version_id == "version-id"
|
||||
assert result.listing_version_id == "version-id"
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
mock_slv.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
|
||||
description="Test description",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
# Call function
|
||||
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by="rating",
|
||||
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
@@ -57,12 +57,6 @@ class StoreError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa."agentGraphId",
|
||||
sa.graph_id,
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
ON c."storeListingVersionId" = sa.listing_version_id
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
ON sa.listing_version_id = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
"agentGraphId",
|
||||
graph_id,
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Self
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import prisma.models
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyAgent]
|
||||
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyUnpublishedAgent]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||
return cls(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.graph_id,
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[StoreAgent]
|
||||
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
graph_id: str
|
||||
graph_versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
active_version_id: str
|
||||
has_approved_version: bool
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[Creator]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class CreatorDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str]
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||
return cls(
|
||||
store_listing_version_id=agent.listing_version_id,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
graph_id=agent.graph_id,
|
||||
graph_versions=agent.graph_versions,
|
||||
last_updated=agent.updated_at,
|
||||
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||
active_version_id=agent.listing_version_id,
|
||||
has_approved_version=True, # StoreAgent view only has approved agents
|
||||
)
|
||||
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
name: str
|
||||
"""Marketplace user profile (only attributes that the user can update)"""
|
||||
|
||||
username: str
|
||||
name: str
|
||||
description: str
|
||||
avatar_url: str | None
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class ProfileDetails(Profile):
|
||||
"""Marketplace user profile (including read-only fields)"""
|
||||
|
||||
is_featured: bool
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||
return cls(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatarUrl,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
is_featured=profile.isFeatured,
|
||||
)
|
||||
|
||||
|
||||
class CreatorDetails(ProfileDetails):
|
||||
"""Marketplace creator profile details, including aggregated stats"""
|
||||
|
||||
num_agents: int
|
||||
agent_runs: int
|
||||
agent_rating: float
|
||||
top_categories: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||
return cls(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
avatar_url=creator.avatar_url,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
is_featured=creator.is_featured,
|
||||
num_agents=creator.num_agents,
|
||||
agent_runs=creator.agent_runs,
|
||||
agent_rating=creator.agent_rating,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[CreatorDetails]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
# From StoreListing:
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
user_id: str
|
||||
slug: str
|
||||
|
||||
# From StoreListingVersion:
|
||||
listing_version_id: str
|
||||
listing_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
instructions: str | None
|
||||
categories: list[str]
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
video_url: str | None
|
||||
agent_output_demo_url: str | None
|
||||
|
||||
submitted_at: datetime.datetime | None
|
||||
changes_summary: str | None
|
||||
status: prisma.enums.SubmissionStatus
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
# Additional fields for editing
|
||||
video_url: str | None = None
|
||||
agent_output_demo_url: str | None = None
|
||||
categories: list[str] = []
|
||||
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count: int = 0
|
||||
review_count: int = 0
|
||||
review_avg_rating: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
"""Construct from the StoreSubmission Prisma view."""
|
||||
return cls(
|
||||
listing_id=_sub.listing_id,
|
||||
user_id=_sub.user_id,
|
||||
slug=_sub.slug,
|
||||
listing_version_id=_sub.listing_version_id,
|
||||
listing_version=_sub.listing_version,
|
||||
graph_id=_sub.graph_id,
|
||||
graph_version=_sub.graph_version,
|
||||
name=_sub.name,
|
||||
sub_heading=_sub.sub_heading,
|
||||
description=_sub.description,
|
||||
instructions=_sub.instructions,
|
||||
categories=_sub.categories,
|
||||
image_urls=_sub.image_urls,
|
||||
video_url=_sub.video_url,
|
||||
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||
submitted_at=_sub.submitted_at,
|
||||
changes_summary=_sub.changes_summary,
|
||||
status=_sub.status,
|
||||
reviewed_at=_sub.reviewed_at,
|
||||
reviewer_id=_sub.reviewer_id,
|
||||
review_comments=_sub.review_comments,
|
||||
run_count=_sub.run_count,
|
||||
review_count=_sub.review_count,
|
||||
review_avg_rating=_sub.review_avg_rating,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
"""
|
||||
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||
"""
|
||||
if not (_l := _lv.StoreListing):
|
||||
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||
|
||||
return cls(
|
||||
listing_id=_l.id,
|
||||
user_id=_l.owningUserId,
|
||||
slug=_l.slug,
|
||||
listing_version_id=_lv.id,
|
||||
listing_version=_lv.version,
|
||||
graph_id=_lv.agentGraphId,
|
||||
graph_version=_lv.agentGraphVersion,
|
||||
name=_lv.name,
|
||||
sub_heading=_lv.subHeading,
|
||||
description=_lv.description,
|
||||
instructions=_lv.instructions,
|
||||
categories=_lv.categories,
|
||||
image_urls=_lv.imageUrls,
|
||||
video_url=_lv.videoUrl,
|
||||
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||
submitted_at=_lv.submittedAt,
|
||||
changes_summary=_lv.changesSummary,
|
||||
status=_lv.submissionStatus,
|
||||
reviewed_at=_lv.reviewedAt,
|
||||
reviewer_id=_lv.reviewerId,
|
||||
review_comments=_lv.reviewComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
graph_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Graph ID cannot be empty"
|
||||
)
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
graph_version: int = pydantic.Field(
|
||||
..., gt=0, description="Graph version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str | None = None
|
||||
class StoreSubmissionAdminView(StoreSubmission):
|
||||
internal_comments: str | None # Private admin notes
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_db(_sub).model_dump(),
|
||||
internal_comments=_sub.internal_comments,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||
internal_comments=_lv.internalComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
graph_id: str
|
||||
slug: str
|
||||
active_listing_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmissionAdminView | None = None
|
||||
versions: list[StoreSubmissionAdminView] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersionsAdminView]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreReview(pydantic.BaseModel):
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
assert pagination.total_pages == 5
|
||||
assert pagination.current_page == 2
|
||||
assert pagination.page_size == 20
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.agents) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
agent_output_demo="demo.mp4",
|
||||
agent_image=["image1.jpg", "image2.jpg"],
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
assert len(details.agent_image) == 2
|
||||
assert len(details.categories) == 2
|
||||
assert len(details.versions) == 2
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
assert creator.name == "Test Creator"
|
||||
assert creator.num_agents == 5
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.creators) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["cat1", "cat2"],
|
||||
)
|
||||
assert details.name == "Test Creator"
|
||||
assert len(details.links) == 2
|
||||
assert details.agent_rating == 4.8
|
||||
assert len(details.top_categories) == 2
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
assert submission.name == "Test Agent"
|
||||
assert len(submission.image_urls) == 2
|
||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.submissions) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
sub_heading="Test subheading",
|
||||
video_url="video.mp4",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
)
|
||||
assert request.agent_id == "agent123"
|
||||
assert request.agent_version == 1
|
||||
assert len(request.image_urls) == 2
|
||||
assert len(request.categories) == 2
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
from typing import Literal
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
from fastapi import Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.ProfileDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Get the profile details for the authenticated user."""
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
raise NotFoundError("User does not have a profile yet")
|
||||
return profile
|
||||
|
||||
|
||||
@@ -57,98 +51,17 @@ async def get_profile(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.CreatorDetails,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: store_model.Profile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Update the store profile for the authenticated user.
|
||||
|
||||
Args:
|
||||
profile (Profile): The updated profile details
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
CreatorDetails: The updated profile
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Update the store profile for the authenticated user."""
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
@@ -158,60 +71,30 @@ async def get_agents(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||
default=None,
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
description="Content types to search. If not specified, searches all.",
|
||||
),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
user_id: str | None = Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
):
|
||||
) -> store_model.UnifiedSearchResponse:
|
||||
"""
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
Search across all content types (marketplace agents, blocks, documentation)
|
||||
using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_type_enums,
|
||||
content_types=content_types,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -245,22 +128,69 @@ async def unified_search(
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured agents"
|
||||
),
|
||||
creator: str | None = Query(
|
||||
default=None, description="Filter agents by creator username"
|
||||
),
|
||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||
default=None,
|
||||
description="Property to sort results by. Ignored if search_query is provided.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the marketplace,
|
||||
with optional filtering and sorting.
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent(
|
||||
async def get_agent_by_name(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
include_changelog: bool = Query(default=False),
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get details of a marketplace agent"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
@@ -270,76 +200,82 @@ async def get_agent(
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreReview,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_review(
|
||||
async def post_user_review_for_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a review for a store agent.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
review: Review details including score and optional comments
|
||||
user_id: ID of authenticated user creating the review
|
||||
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreReview:
|
||||
"""Post a user review on a marketplace agent listing"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_agent_by_listing_version(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############# Creator Endpoints #############
|
||||
##############################################
|
||||
@@ -349,37 +285,19 @@ async def create_review(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
This is needed for:
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
- featured: bool - to limit the list to just featured agents
|
||||
- search_query: str - vector search based on the creators profile description.
|
||||
- sorted_by: [agent_rating, agent_runs] -
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured creators"
|
||||
),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""List or search marketplace creators"""
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
@@ -391,18 +309,12 @@ async def get_creators(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
"/creators/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
"""Get details on a marketplace creator"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
@@ -414,20 +326,17 @@ async def get_creator(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/myagents",
|
||||
"/my-unpublished-agents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
async def get_my_unpublished_agents(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.MyUnpublishedAgentsResponse:
|
||||
"""List the authenticated user's unpublished agents"""
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
@@ -436,28 +345,17 @@ async def get_my_agents(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=bool,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
submission_id (str): ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> bool:
|
||||
"""Delete a marketplace listing submission"""
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -465,37 +363,14 @@ async def delete_submission(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreListingsResponse: Paginated list of store submissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""List the authenticated user's marketplace listing submissions"""
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
@@ -508,30 +383,17 @@ async def get_submissions(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
|
||||
Args:
|
||||
submission_request (StoreSubmissionRequest): The submission details
|
||||
user_id (str): ID of the authenticated user submitting the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Submit a new marketplace listing for review"""
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
graph_id=submission_request.graph_id,
|
||||
graph_version=submission_request.graph_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
@@ -544,7 +406,6 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -552,28 +413,14 @@ async def create_submission(
|
||||
"/submissions/{store_listing_version_id}",
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to edit
|
||||
submission_request (StoreSubmissionRequest): The updated submission details
|
||||
user_id (str): ID of the authenticated user editing the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The updated store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Update a pending marketplace listing submission"""
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
@@ -588,7 +435,6 @@ async def edit_submission(
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -596,115 +442,61 @@ async def edit_submission(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def upload_submission_media(
|
||||
file: fastapi.UploadFile,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Upload media (images/videos) for a store listing submission.
|
||||
|
||||
Args:
|
||||
file (UploadFile): The media file to upload
|
||||
user_id (str): ID of the authenticated user uploading the media
|
||||
|
||||
Returns:
|
||||
str: URL of the uploaded media file
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> str:
|
||||
"""Upload media for a marketplace listing submission"""
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
class ImageURLResponse(BaseModel):
|
||||
image_url: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def generate_image(
|
||||
agent_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> fastapi.responses.Response:
|
||||
graph_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> ImageURLResponse:
|
||||
"""
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
Generate an image for a marketplace listing submission based on the properties
|
||||
of a given graph.
|
||||
"""
|
||||
agent = await backend.data.graph.get_graph(
|
||||
graph_id=agent_id, version=None, user_id=user_id
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=user_id
|
||||
)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||
return ImageURLResponse(image_url=existing_url)
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
return ImageURLResponse(image_url=image_url)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by="runs",
|
||||
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
@@ -380,9 +382,11 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
graph_versions=["1", "2"],
|
||||
graph_id="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
active_version_id="test-version-id",
|
||||
has_approved_version=True,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
|
||||
) -> None:
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
store_model.CreatorDetails(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
description=f"Creator {i} description",
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
num_agents=1,
|
||||
agent_rating=4.5,
|
||||
agent_runs=100,
|
||||
description=f"Creator {i} description",
|
||||
links=[f"user{i}.link.com"],
|
||||
is_featured=False,
|
||||
num_agents=1,
|
||||
agent_runs=100,
|
||||
agent_rating=4.5,
|
||||
top_categories=["cat1", "cat2", "cat3"],
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
@@ -496,19 +502,19 @@ def test_get_creator_details(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
avatar_url="avatar.jpg",
|
||||
description="Test creator description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
is_featured=True,
|
||||
num_agents=5,
|
||||
agent_runs=1000,
|
||||
agent_rating=4.8,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creator/creator1")
|
||||
response = client.get("/creators/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
@@ -528,19 +534,26 @@ def test_get_submissions_success(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
image_urls=["test.jpg"],
|
||||
date_submitted=FIXED_NOW,
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
runs=50,
|
||||
rating=4.2,
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
sub_heading="Test agent subheading",
|
||||
user_id="test-user-id",
|
||||
slug="test-agent",
|
||||
video_url="test.mp4",
|
||||
listing_version_id="test-version-id",
|
||||
listing_version=1,
|
||||
graph_id="test-agent-id",
|
||||
graph_version=1,
|
||||
name="Test Agent",
|
||||
sub_heading="Test agent subheading",
|
||||
description="Test agent description",
|
||||
instructions="Click the button!",
|
||||
categories=["test-category"],
|
||||
image_urls=["test.jpg"],
|
||||
video_url="test.mp4",
|
||||
agent_output_demo_url="demo_video.mp4",
|
||||
submitted_at=FIXED_NOW,
|
||||
changes_summary="Initial Submission",
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
run_count=50,
|
||||
review_count=5,
|
||||
review_avg_rating=4.2,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .db import StoreAgentsSortOptions
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
@@ -215,7 +216,7 @@ class TestCacheDeletion:
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -227,7 +228,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -239,7 +240,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
|
||||
@@ -55,6 +55,7 @@ from backend.data.credit import (
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
from backend.data.invited_user import get_or_activate_user
|
||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
|
||||
update_user_onboarding,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_or_create_user,
|
||||
get_user_by_id,
|
||||
get_user_notification_preference,
|
||||
update_user_email,
|
||||
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
user = await get_or_activate_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||
# not produce a stored result before first activation.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
@@ -165,7 +163,8 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
email: str = Body(...),
|
||||
) -> dict[str, str]:
|
||||
await update_user_email(user_id, email)
|
||||
|
||||
@@ -179,10 +178,16 @@ async def update_user_email_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_user_timezone_route(
|
||||
user_data: dict = Security(get_jwt_payload),
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> TimezoneResponse:
|
||||
"""Get user timezone setting."""
|
||||
user = await get_or_create_user(user_data)
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail="User not found. Please complete activation via /auth/user first.",
|
||||
)
|
||||
return TimezoneResponse(timezone=user.timezone)
|
||||
|
||||
|
||||
@@ -193,7 +198,8 @@ async def get_user_timezone_route(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
request: UpdateTimezoneRequest,
|
||||
) -> TimezoneResponse:
|
||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||
user = await update_user_timezone(user_id, str(request.timezone))
|
||||
@@ -449,7 +455,6 @@ async def execute_graph_block(
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
@@ -512,7 +517,6 @@ async def upload_file(
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.v1.get_or_create_user",
|
||||
"backend.api.features.v1.get_or_activate_user",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
@@ -515,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
|
||||
@@ -533,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
|
||||
mock_handler.store_file.assert_called_once_with(
|
||||
content=file_content,
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
@@ -120,6 +120,10 @@ class UploadFileResponse(BaseModel):
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
@@ -151,6 +155,31 @@ async def download_file(
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> DeleteFileResponse:
|
||||
"""
|
||||
Soft-delete a workspace file and attempt to remove it from storage.
|
||||
|
||||
Used when a user clears a file input in the builder.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id)
|
||||
deleted = await manager.delete_file(file_id)
|
||||
if not deleted:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return DeleteFileResponse(deleted=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
@@ -218,7 +247,10 @@ async def upload_file(
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
|
||||
@@ -305,3 +305,55 @@ def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"deleted": True}
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
|
||||
@@ -94,3 +94,8 @@ class NotificationPayload(pydantic.BaseModel):
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: OnboardingStep | None
|
||||
|
||||
|
||||
class CopilotCompletionPayload(NotificationPayload):
|
||||
session_id: str
|
||||
status: Literal["completed", "failed"]
|
||||
|
||||
@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.admin.user_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
import backend.api.features.chat.routes as chat_routes
|
||||
@@ -55,6 +56,7 @@ from backend.util.exceptions import (
|
||||
MissingConfigError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
@@ -275,6 +277,7 @@ app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
@@ -309,6 +312,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/executions",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.user_admin_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/users",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
|
||||
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
is_graph_execution: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
graph_version=graph_version,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
|
||||
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
||||
disabled=True,
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
|
||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
execution = await sandbox.run_code( # type: ignore[attr-defined]
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
|
||||
@@ -96,6 +96,7 @@ class SendEmailBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
3
autogpt_platform/backend/backend/blocks/github/_utils.py
Normal file
3
autogpt_platform/backend/backend/blocks/github/_utils.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def github_repo_path(repo_url: str) -> str:
|
||||
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||
return repo_url.replace("https://github.com/", "")
|
||||
408
autogpt_platform/backend/backend/blocks/github/commits.py
Normal file
408
autogpt_platform/backend/backend/blocks/github/commits.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import asyncio
|
||||
from enum import StrEnum
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListCommitsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to list commits from",
|
||||
default="main",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of commits to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class CommitItem(TypedDict):
|
||||
sha: str
|
||||
message: str
|
||||
author: str
|
||||
date: str
|
||||
url: str
|
||||
|
||||
commit: CommitItem = SchemaField(
|
||||
title="Commit", description="A commit with its details"
|
||||
)
|
||||
commits: list[CommitItem] = SchemaField(
|
||||
description="List of commits with their details"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing commits failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||
description="This block lists commits on a branch in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommitsBlock.Input,
|
||||
output_schema=GithubListCommitsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"commits",
|
||||
[
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"commit",
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_commits": lambda *args, **kwargs: [
|
||||
{
|
||||
"sha": "abc123",
|
||||
"message": "Initial commit",
|
||||
"author": "octocat",
|
||||
"date": "2024-01-01T00:00:00Z",
|
||||
"url": "https://github.com/owner/repo/commit/abc123",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_commits(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
per_page: int,
|
||||
page: int,
|
||||
) -> list[Output.CommitItem]:
|
||||
api = get_api(credentials)
|
||||
commits_url = repo_url + "/commits"
|
||||
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||
response = await api.get(commits_url, params=params)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
return [
|
||||
GithubListCommitsBlock.Output.CommitItem(
|
||||
sha=c["sha"],
|
||||
message=c["commit"]["message"],
|
||||
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||
)
|
||||
for c in data
|
||||
]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
commits = await self.list_commits(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "commits", commits
|
||||
for commit in commits:
|
||||
yield "commit", commit
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class FileOperation(StrEnum):
|
||||
"""File operations for GithubMultiFileCommitBlock.
|
||||
|
||||
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||
between creation and update — the blob is placed at the given path regardless
|
||||
of whether a file already exists there).
|
||||
|
||||
DELETE removes a file from the tree.
|
||||
"""
|
||||
|
||||
UPSERT = "upsert"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
class GithubMultiFileCommitBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to commit to",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Commit message",
|
||||
placeholder="Add new feature",
|
||||
)
|
||||
files: list[FileOperationInput] = SchemaField(
|
||||
description=(
|
||||
"List of file operations. Each item has: "
|
||||
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||
"'operation' (upsert/delete)"
|
||||
),
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the new commit")
|
||||
url: str = SchemaField(description="URL of the new commit")
|
||||
error: str = SchemaField(description="Error message if the commit failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||
description=(
|
||||
"This block creates a single commit with multiple file "
|
||||
"upsert/delete operations using the Git Trees API."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMultiFileCommitBlock.Input,
|
||||
output_schema=GithubMultiFileCommitBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "Add files",
|
||||
"files": [
|
||||
{
|
||||
"path": "src/new.py",
|
||||
"content": "print('hello')",
|
||||
"operation": "upsert",
|
||||
},
|
||||
{
|
||||
"path": "src/old.py",
|
||||
"content": "",
|
||||
"operation": "delete",
|
||||
},
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "newcommitsha"),
|
||||
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||
],
|
||||
test_mock={
|
||||
"multi_file_commit": lambda *args, **kwargs: (
|
||||
"newcommitsha",
|
||||
"https://github.com/owner/repo/commit/newcommitsha",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def multi_file_commit(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
files: list[FileOperationInput],
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
safe_branch = quote(branch, safe="")
|
||||
|
||||
# 1. Get the latest commit SHA for the branch
|
||||
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||
response = await api.get(ref_url)
|
||||
ref_data = response.json()
|
||||
latest_commit_sha = ref_data["object"]["sha"]
|
||||
|
||||
# 2. Get the tree SHA of the latest commit
|
||||
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||
response = await api.get(commit_url)
|
||||
commit_data = response.json()
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
tree_entries: list[dict] = []
|
||||
upsert_files = []
|
||||
for file_op in files:
|
||||
path = file_op["path"]
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
|
||||
if operation == FileOperation.DELETE:
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": None, # null SHA = delete
|
||||
}
|
||||
)
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
{
|
||||
"path": path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
|
||||
# 4. Create a new tree
|
||||
tree_url = repo_url + "/git/trees"
|
||||
tree_response = await api.post(
|
||||
tree_url,
|
||||
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||
)
|
||||
new_tree_sha = tree_response.json()["sha"]
|
||||
|
||||
# 5. Create a new commit
|
||||
new_commit_url = repo_url + "/git/commits"
|
||||
commit_response = await api.post(
|
||||
new_commit_url,
|
||||
json={
|
||||
"message": commit_message,
|
||||
"tree": new_tree_sha,
|
||||
"parents": [latest_commit_sha],
|
||||
},
|
||||
)
|
||||
new_commit_sha = commit_response.json()["sha"]
|
||||
|
||||
# 6. Update the branch reference
|
||||
try:
|
||||
await api.patch(
|
||||
ref_url,
|
||||
json={"sha": new_commit_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Commit {new_commit_sha} was created but failed to update "
|
||||
f"ref heads/{branch}: {e}. "
|
||||
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||
) from e
|
||||
|
||||
repo_path = github_repo_path(repo_url)
|
||||
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||
return new_commit_sha, commit_web_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -20,6 +21,8 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
@@ -558,12 +561,109 @@ class GithubListPRReviewersBlock(Block):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
class GithubMergePullRequestBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
merge_method: MergeMethod = SchemaField(
|
||||
description="Merge method to use: merge, squash, or rebase",
|
||||
default="merge",
|
||||
)
|
||||
commit_title: str = SchemaField(
|
||||
description="Title for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the merge commit (optional, used for merge and squash)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
sha: str = SchemaField(description="SHA of the merge commit")
|
||||
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||
message: str = SchemaField(description="Merge status message")
|
||||
error: str = SchemaField(description="Error message if the merge failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||
description="This block merges a pull request using merge, squash, or rebase.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMergePullRequestBlock.Input,
|
||||
output_schema=GithubMergePullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("sha", "abc123"),
|
||||
("merged", True),
|
||||
("message", "Pull Request successfully merged"),
|
||||
],
|
||||
test_mock={
|
||||
"merge_pr": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
True,
|
||||
"Pull Request successfully merged",
|
||||
)
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def merge_pr(
|
||||
credentials: GithubCredentials,
|
||||
pr_url: str,
|
||||
merge_method: MergeMethod,
|
||||
commit_title: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, bool, str]:
|
||||
api = get_api(credentials)
|
||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||
data: dict[str, str] = {"merge_method": merge_method}
|
||||
if commit_title:
|
||||
data["commit_title"] = commit_title
|
||||
if commit_message:
|
||||
data["commit_message"] = commit_message
|
||||
response = await api.put(merge_url, json=data)
|
||||
result = response.json()
|
||||
return result["sha"], result["merged"], result["message"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sha, merged, message = await self.merge_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.merge_method,
|
||||
input_data.commit_title,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "merged", merged
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||
# Pattern to capture the base repository URL and the pull request number
|
||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||
match = re.match(pattern, pr_url)
|
||||
if not match:
|
||||
return pr_url
|
||||
|
||||
base_url, pr_number = match.groups()
|
||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||
scheme, base_url, pr_number = match.groups()
|
||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import base64
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -19,6 +17,7 @@ from ._auth import (
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
@@ -89,7 +88,7 @@ class GithubListTagsBlock(Block):
|
||||
tags_url = repo_url + "/tags"
|
||||
response = await api.get(tags_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
repo_path = github_repo_path(repo_url)
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
@@ -115,101 +114,6 @@ class GithubListTagsBlock(Block):
|
||||
yield "tag", tag
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(branches_url)
|
||||
data = response.json()
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -283,7 +187,7 @@ class GithubListDiscussionsBlock(Block):
|
||||
) -> list[Output.DiscussionItem]:
|
||||
api = get_api(credentials)
|
||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
@@ -416,564 +320,6 @@ class GithubListReleasesBlock(Block):
|
||||
yield "release", release
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{file_path}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -1103,7 +449,7 @@ class GithubListStargazersBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||
description="This block lists all users who have starred a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListStargazersBlock.Input,
|
||||
@@ -1172,3 +518,230 @@ class GithubListStargazersBlock(Block):
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
|
||||
class GithubGetRepositoryInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
name: str = SchemaField(description="Repository name")
|
||||
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||
description: str = SchemaField(description="Repository description")
|
||||
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||
private: bool = SchemaField(description="Whether the repository is private")
|
||||
html_url: str = SchemaField(description="Web URL of the repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL")
|
||||
stars: int = SchemaField(description="Number of stars")
|
||||
forks: int = SchemaField(description="Number of forks")
|
||||
open_issues: int = SchemaField(description="Number of open issues")
|
||||
error: str = SchemaField(
|
||||
description="Error message if fetching repo info failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||
description="This block retrieves metadata about a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("name", "repo"),
|
||||
("full_name", "owner/repo"),
|
||||
("description", "A test repo"),
|
||||
("default_branch", "main"),
|
||||
("private", False),
|
||||
("html_url", "https://github.com/owner/repo"),
|
||||
("clone_url", "https://github.com/owner/repo.git"),
|
||||
("stars", 42),
|
||||
("forks", 5),
|
||||
("open_issues", 3),
|
||||
],
|
||||
test_mock={
|
||||
"get_repo_info": lambda *args, **kwargs: {
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"description": "A test repo",
|
||||
"default_branch": "main",
|
||||
"private": False,
|
||||
"html_url": "https://github.com/owner/repo",
|
||||
"clone_url": "https://github.com/owner/repo.git",
|
||||
"stargazers_count": 42,
|
||||
"forks_count": 5,
|
||||
"open_issues_count": 3,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||
api = get_api(credentials)
|
||||
response = await api.get(repo_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||
yield "name", data["name"]
|
||||
yield "full_name", data["full_name"]
|
||||
yield "description", data.get("description", "") or ""
|
||||
yield "default_branch", data["default_branch"]
|
||||
yield "private", data["private"]
|
||||
yield "html_url", data["html_url"]
|
||||
yield "clone_url", data["clone_url"]
|
||||
yield "stars", data["stargazers_count"]
|
||||
yield "forks", data["forks_count"]
|
||||
yield "open_issues", data["open_issues_count"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubForkRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to fork",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
organization: str = SchemaField(
|
||||
description="Organization to fork into (leave empty to fork to your account)",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the forked repository")
|
||||
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||
error: str = SchemaField(description="Error message if the fork failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||
description="This block forks a GitHub repository to your account or an organization.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubForkRepositoryBlock.Input,
|
||||
output_schema=GithubForkRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"organization": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/myuser/repo"),
|
||||
("clone_url", "https://github.com/myuser/repo.git"),
|
||||
("full_name", "myuser/repo"),
|
||||
],
|
||||
test_mock={
|
||||
"fork_repo": lambda *args, **kwargs: (
|
||||
"https://github.com/myuser/repo",
|
||||
"https://github.com/myuser/repo.git",
|
||||
"myuser/repo",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def fork_repo(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
organization: str,
|
||||
) -> tuple[str, str, str]:
|
||||
api = get_api(credentials)
|
||||
forks_url = repo_url + "/forks"
|
||||
data: dict[str, str] = {}
|
||||
if organization:
|
||||
data["organization"] = organization
|
||||
response = await api.post(forks_url, json=data)
|
||||
result = response.json()
|
||||
return result["html_url"], result["clone_url"], result["full_name"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, clone_url, full_name = await self.fork_repo(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.organization,
|
||||
)
|
||||
yield "url", url
|
||||
yield "clone_url", clone_url
|
||||
yield "full_name", full_name
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubStarRepositoryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository to star",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the star operation")
|
||||
error: str = SchemaField(description="Error message if starring failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||
description="This block stars a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubStarRepositoryBlock.Input,
|
||||
output_schema=GithubStarRepositoryBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Repository starred successfully")],
|
||||
test_mock={
|
||||
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
repo_path = github_repo_path(repo_url)
|
||||
owner, repo = repo_path.split("/")
|
||||
await api.put(
|
||||
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||
headers={"Content-Length": "0"},
|
||||
)
|
||||
return "Repository starred successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.star_repo(credentials, input_data.repo_url)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
452
autogpt_platform/backend/backend/blocks/github/repo_branches.py
Normal file
452
autogpt_platform/backend/backend/blocks/github/repo_branches.py
Normal file
@@ -0,0 +1,452 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
from ._utils import github_repo_path
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of branches to return per page (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"per_page": 30,
|
||||
"page": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||
) -> list[Output.BranchItem]:
|
||||
api = get_api(credentials)
|
||||
branches_url = repo_url + "/branches"
|
||||
response = await api.get(
|
||||
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||
)
|
||||
data = response.json()
|
||||
repo_path = github_repo_path(repo_url)
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{
|
||||
"name": branch["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||
}
|
||||
for branch in data
|
||||
]
|
||||
return branches
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = await self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.per_page,
|
||||
input_data.page,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||
response = await api.get(ref_url)
|
||||
data = response.json()
|
||||
sha = data["object"]["sha"]
|
||||
|
||||
# Create the new branch
|
||||
new_ref_url = repo_url + "/git/refs"
|
||||
data = {
|
||||
"ref": f"refs/heads/{new_branch}",
|
||||
"sha": sha,
|
||||
}
|
||||
response = await api.post(new_ref_url, json=data)
|
||||
return "Branch created successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
api = get_api(credentials)
|
||||
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||
await api.delete(ref_url)
|
||||
return "Branch deleted successfully"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = await self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCompareBranchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="Base branch or commit SHA",
|
||||
placeholder="main",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="Head branch or commit SHA to compare against base",
|
||||
placeholder="feature-branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class FileChange(TypedDict):
|
||||
filename: str
|
||||
status: str
|
||||
additions: int
|
||||
deletions: int
|
||||
patch: str
|
||||
|
||||
status: str = SchemaField(
|
||||
description="Comparison status: ahead, behind, diverged, or identical"
|
||||
)
|
||||
ahead_by: int = SchemaField(
|
||||
description="Number of commits head is ahead of base"
|
||||
)
|
||||
behind_by: int = SchemaField(
|
||||
description="Number of commits head is behind base"
|
||||
)
|
||||
total_commits: int = SchemaField(
|
||||
description="Total number of commits in the comparison"
|
||||
)
|
||||
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||
file: FileChange = SchemaField(
|
||||
title="Changed File", description="A changed file with its diff"
|
||||
)
|
||||
files: list[FileChange] = SchemaField(
|
||||
description="List of changed files with their diffs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comparison failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||
description="This block compares two branches or commits in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCompareBranchesBlock.Input,
|
||||
output_schema=GithubCompareBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"base": "main",
|
||||
"head": "feature",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("status", "ahead"),
|
||||
("ahead_by", 2),
|
||||
("behind_by", 0),
|
||||
("total_commits", 2),
|
||||
("diff", "+++ b/file.py\n+new line"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"compare_branches": lambda *args, **kwargs: {
|
||||
"status": "ahead",
|
||||
"ahead_by": 2,
|
||||
"behind_by": 0,
|
||||
"total_commits": 2,
|
||||
"files": [
|
||||
{
|
||||
"filename": "file.py",
|
||||
"status": "modified",
|
||||
"additions": 1,
|
||||
"deletions": 0,
|
||||
"patch": "+new line",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def compare_branches(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
base: str,
|
||||
head: str,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
safe_base = quote(base, safe="")
|
||||
safe_head = quote(head, safe="")
|
||||
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||
response = await api.get(compare_url)
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data = await self.compare_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.base,
|
||||
input_data.head,
|
||||
)
|
||||
yield "status", data["status"]
|
||||
yield "ahead_by", data["ahead_by"]
|
||||
yield "behind_by", data["behind_by"]
|
||||
yield "total_commits", data["total_commits"]
|
||||
|
||||
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||
GithubCompareBranchesBlock.Output.FileChange(
|
||||
filename=f["filename"],
|
||||
status=f["status"],
|
||||
additions=f["additions"],
|
||||
deletions=f["deletions"],
|
||||
patch=f.get("patch", ""),
|
||||
)
|
||||
for f in data.get("files", [])
|
||||
]
|
||||
|
||||
# Build unified diff
|
||||
diff_parts = []
|
||||
for f in data.get("files", []):
|
||||
patch = f.get("patch", "")
|
||||
if patch:
|
||||
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||
yield "diff", "\n".join(diff_parts)
|
||||
|
||||
yield "files", files
|
||||
for file in files:
|
||||
yield "file", file
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
720
autogpt_platform/backend/backend/blocks/github/repo_files.py
Normal file
720
autogpt_platform/backend/backend/blocks/github/repo_files.py
Normal file
@@ -0,0 +1,720 @@
|
||||
import base64
|
||||
from urllib.parse import quote
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if reading the file failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
api = get_api(credentials)
|
||||
content_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(content_url)
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
data = file
|
||||
|
||||
if data["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return data["content"], data["size"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
content, size = await self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", content
|
||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to main)",
|
||||
placeholder="branch_name",
|
||||
default="main",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
api = get_api(credentials)
|
||||
contents_url = (
|
||||
repo_url
|
||||
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||
)
|
||||
response = await api.get(contents_url)
|
||||
data = response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "file"
|
||||
]
|
||||
|
||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in data
|
||||
if entry["type"] == "dir"
|
||||
]
|
||||
|
||||
return files, dirs
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = await self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
for file in files:
|
||||
yield "file", file
|
||||
for dir in dirs:
|
||||
yield "dir", dir
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path where the file should be created",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="Content to write to the file",
|
||||
placeholder="File content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch where the file should be created",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Create new file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the created file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the file creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||
description="This block creates a new file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateFileBlock.Input,
|
||||
output_schema=GithubCreateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Test content",
|
||||
"branch": "main",
|
||||
"commit_message": "Create test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "abc123"),
|
||||
],
|
||||
test_mock={
|
||||
"create_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"abc123",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.create_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file to update",
|
||||
placeholder="path/to/file.txt",
|
||||
)
|
||||
content: str = SchemaField(
|
||||
description="New content for the file",
|
||||
placeholder="Updated content here",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch containing the file",
|
||||
default="main",
|
||||
)
|
||||
commit_message: str = SchemaField(
|
||||
description="Message for the commit",
|
||||
default="Update file",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
url: str = SchemaField(description="URL of the updated file")
|
||||
sha: str = SchemaField(description="SHA of the commit")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||
description="This block updates an existing file in a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateFileBlock.Input,
|
||||
output_schema=GithubUpdateFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "test/file.txt",
|
||||
"content": "Updated content",
|
||||
"branch": "main",
|
||||
"commit_message": "Update test file",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||
("sha", "def456"),
|
||||
],
|
||||
test_mock={
|
||||
"update_file": lambda *args, **kwargs: (
|
||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||
"def456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_file(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
file_path: str,
|
||||
content: str,
|
||||
branch: str,
|
||||
commit_message: str,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials)
|
||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||
params = {"ref": branch}
|
||||
response = await api.get(contents_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
# Convert new content to base64
|
||||
content_base64 = base64.b64encode(content.encode()).decode()
|
||||
data = {
|
||||
"message": commit_message,
|
||||
"content": content_base64,
|
||||
"sha": data["sha"],
|
||||
"branch": branch,
|
||||
}
|
||||
response = await api.put(contents_url, json=data)
|
||||
data = response.json()
|
||||
return data["content"]["html_url"], data["commit"]["sha"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
url, sha = await self.update_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path,
|
||||
input_data.content,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
)
|
||||
yield "url", url
|
||||
yield "sha", sha
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubSearchCodeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
query: str = SchemaField(
|
||||
description="Search query (GitHub code search syntax)",
|
||||
placeholder="className language:python",
|
||||
)
|
||||
repo: str = SchemaField(
|
||||
description="Restrict search to a repository (owner/repo format, optional)",
|
||||
default="",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=30,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class SearchResult(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
repository: str
|
||||
url: str
|
||||
score: float
|
||||
|
||||
result: SearchResult = SchemaField(
|
||||
title="Result", description="A code search result"
|
||||
)
|
||||
results: list[SearchResult] = SchemaField(
|
||||
description="List of code search results"
|
||||
)
|
||||
total_count: int = SchemaField(description="Total number of matching results")
|
||||
error: str = SchemaField(description="Error message if search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||
description="This block searches for code in GitHub repositories.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSearchCodeBlock.Input,
|
||||
output_schema=GithubSearchCodeBlock.Output,
|
||||
test_input={
|
||||
"query": "addClass",
|
||||
"repo": "owner/repo",
|
||||
"per_page": 30,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("total_count", 1),
|
||||
(
|
||||
"results",
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_code": lambda *args, **kwargs: (
|
||||
1,
|
||||
[
|
||||
{
|
||||
"name": "file.py",
|
||||
"path": "src/file.py",
|
||||
"repository": "owner/repo",
|
||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||
"score": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_code(
|
||||
credentials: GithubCredentials,
|
||||
query: str,
|
||||
repo: str,
|
||||
per_page: int,
|
||||
) -> tuple[int, list[Output.SearchResult]]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
full_query = f"{query} repo:{repo}" if repo else query
|
||||
params = {"q": full_query, "per_page": str(per_page)}
|
||||
response = await api.get("https://api.github.com/search/code", params=params)
|
||||
data = response.json()
|
||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||
GithubSearchCodeBlock.Output.SearchResult(
|
||||
name=item["name"],
|
||||
path=item["path"],
|
||||
repository=item["repository"]["full_name"],
|
||||
url=item["html_url"],
|
||||
score=item["score"],
|
||||
)
|
||||
for item in data["items"]
|
||||
]
|
||||
return data["total_count"], results
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
total_count, results = await self.search_code(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.repo,
|
||||
input_data.per_page,
|
||||
)
|
||||
yield "total_count", total_count
|
||||
yield "results", results
|
||||
for result in results:
|
||||
yield "result", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetRepositoryTreeBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to get the tree from",
|
||||
default="main",
|
||||
)
|
||||
recursive: bool = SchemaField(
|
||||
description="Whether to recursively list the entire tree",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class TreeEntry(TypedDict):
|
||||
path: str
|
||||
type: str
|
||||
size: int
|
||||
sha: str
|
||||
|
||||
entry: TreeEntry = SchemaField(
|
||||
title="Tree Entry", description="A file or directory in the tree"
|
||||
)
|
||||
entries: list[TreeEntry] = SchemaField(
|
||||
description="List of all files and directories in the tree"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the tree was truncated due to size"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting tree failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "main",
|
||||
"recursive": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("truncated", False),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"entry",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tree": lambda *args, **kwargs: (
|
||||
False,
|
||||
[
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"type": "blob",
|
||||
"size": 1234,
|
||||
"sha": "abc123",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_tree(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
branch: str,
|
||||
recursive: bool,
|
||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||
api = get_api(credentials)
|
||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||
params = {"recursive": "1"} if recursive else {}
|
||||
response = await api.get(tree_url, params=params)
|
||||
data = response.json()
|
||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||
path=item["path"],
|
||||
type=item["type"],
|
||||
size=item.get("size", 0),
|
||||
sha=item["sha"],
|
||||
)
|
||||
for item in data["tree"]
|
||||
]
|
||||
return data.get("truncated", False), entries
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
truncated, entries = await self.get_tree(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.recursive,
|
||||
)
|
||||
yield "truncated", truncated
|
||||
yield "entries", entries
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -0,0 +1,125 @@
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||
from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
|
||||
|
||||
class TestPreparePrApiUrl:
|
||||
def test_https_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||
|
||||
def test_http_scheme_preserved(self):
|
||||
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||
|
||||
def test_no_scheme_defaults_to_https(self):
|
||||
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||
|
||||
def test_reviewers_path(self):
|
||||
result = prepare_pr_api_url(
|
||||
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||
)
|
||||
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||
|
||||
def test_invalid_url_returned_as_is(self):
|
||||
url = "https://example.com/not-a-pr"
|
||||
assert prepare_pr_api_url(url, "merge") == url
|
||||
|
||||
def test_empty_string(self):
|
||||
assert prepare_pr_api_url("", "merge") == ""
|
||||
|
||||
|
||||
# ── Error-path block tests ──
|
||||
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||
# which returns early on empty test_output).
|
||||
|
||||
|
||||
def _mock_block(block, mocks: dict):
|
||||
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||
for name, mock_fn in mocks.items():
|
||||
original = getattr(block, name)
|
||||
if inspect.iscoroutinefunction(original):
|
||||
|
||||
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||
return _fn(*args, **kwargs)
|
||||
|
||||
setattr(block, name, async_mock)
|
||||
else:
|
||||
setattr(block, name, mock_fn)
|
||||
|
||||
|
||||
def _raise(exc: Exception):
|
||||
"""Helper that returns a callable which raises the given exception."""
|
||||
|
||||
def _raiser(*args, **kwargs):
|
||||
raise exc
|
||||
|
||||
return _raiser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_error_path():
|
||||
block = GithubMergePullRequestBlock()
|
||||
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||
input_data = {
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"merge_method": "squash",
|
||||
"commit_title": "",
|
||||
"commit_message": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_file_commit_error_path():
|
||||
block = GithubMultiFileCommitBlock()
|
||||
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||
input_data = {
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "feature",
|
||||
"commit_message": "test",
|
||||
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# ── FileOperation enum tests ──
|
||||
|
||||
|
||||
class TestFileOperation:
|
||||
def test_upsert_value(self):
|
||||
assert FileOperation.UPSERT == "upsert"
|
||||
|
||||
def test_delete_value(self):
|
||||
assert FileOperation.DELETE == "delete"
|
||||
|
||||
def test_invalid_value_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("create")
|
||||
|
||||
def test_invalid_value_raises_typo(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileOperation("upser")
|
||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
||||
h.ignore_links = False
|
||||
h.ignore_images = True
|
||||
return h.handle(html_content)
|
||||
except ImportError:
|
||||
# Fallback: return raw HTML if html2text is not available
|
||||
except Exception:
|
||||
# Keep extraction resilient if html2text is unavailable or fails.
|
||||
return html_content
|
||||
|
||||
# Handle content stored as attachment
|
||||
|
||||
@@ -67,6 +67,7 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
@@ -143,10 +144,11 @@ class HITLReviewHelper:
|
||||
logger.info(
|
||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||
)
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
if is_graph_execution:
|
||||
await HITLReviewHelper.update_node_execution_status(
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.REVIEW,
|
||||
)
|
||||
return None # Signal that execution should pause
|
||||
|
||||
# Mark review as processed if not already done
|
||||
@@ -168,6 +170,7 @@ class HITLReviewHelper:
|
||||
graph_version: int,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
is_graph_execution: bool = True,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
@@ -197,6 +200,7 @@ class HITLReviewHelper:
|
||||
graph_version=graph_version,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
is_graph_execution=is_graph_execution,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
|
||||
@@ -31,10 +31,14 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.request import validate_url_host
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
settings = Settings()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
@@ -116,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -135,19 +140,31 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||
CODESTRAL = "mistralai/codestral-2508"
|
||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -155,9 +172,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||
GROK_3 = "x-ai/grok-3"
|
||||
GROK_4 = "x-ai/grok-4"
|
||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||
@@ -274,6 +293,9 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-6
|
||||
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-6
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
@@ -332,17 +354,41 @@ MODEL_METADATA = {
|
||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||
),
|
||||
# https://openrouter.ai/models
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1050000,
|
||||
8192,
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro Preview 03.25",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 2.5 Pro",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Pro Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
2,
|
||||
),
|
||||
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3 Flash Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||
@@ -350,6 +396,15 @@ MODEL_METADATA = {
|
||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||
),
|
||||
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
65536,
|
||||
"Gemini 3.1 Flash Lite Preview",
|
||||
"OpenRouter",
|
||||
"Google",
|
||||
1,
|
||||
),
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||
"open_router",
|
||||
1048576,
|
||||
@@ -371,12 +426,78 @@ MODEL_METADATA = {
|
||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||
),
|
||||
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||
"open_router",
|
||||
262144,
|
||||
None,
|
||||
"Mistral Large 3 2512",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
None,
|
||||
"Mistral Medium 3.1",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
2,
|
||||
),
|
||||
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Mistral Small 3.2 24B",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.CODESTRAL: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
None,
|
||||
"Codestral 2508",
|
||||
"OpenRouter",
|
||||
"Mistral AI",
|
||||
1,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Translate 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||
"open_router",
|
||||
256000,
|
||||
32768,
|
||||
"Command A Reasoning 08.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
3,
|
||||
),
|
||||
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8192,
|
||||
"Command A Vision 07.2025",
|
||||
"OpenRouter",
|
||||
"Cohere",
|
||||
2,
|
||||
),
|
||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||
),
|
||||
@@ -389,6 +510,15 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
8000,
|
||||
"Sonar Reasoning Pro",
|
||||
"OpenRouter",
|
||||
"Perplexity",
|
||||
2,
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
@@ -434,6 +564,9 @@ MODEL_METADATA = {
|
||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||
),
|
||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||
),
|
||||
@@ -443,6 +576,15 @@ MODEL_METADATA = {
|
||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||
),
|
||||
LlmModel.GROK_3: ModelMetadata(
|
||||
"open_router",
|
||||
131072,
|
||||
131072,
|
||||
"Grok 3",
|
||||
"OpenRouter",
|
||||
"xAI",
|
||||
2,
|
||||
),
|
||||
LlmModel.GROK_4: ModelMetadata(
|
||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||
),
|
||||
@@ -800,6 +942,11 @@ async def llm_call(
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||
await validate_url_host(
|
||||
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||
)
|
||||
|
||||
client = ollama.AsyncClient(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
@@ -821,7 +968,7 @@ async def llm_call(
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from pydantic import SecretStr
|
||||
from pydantic import SecretStr, field_validator
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -13,6 +13,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -21,6 +22,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||
@@ -34,6 +36,20 @@ class PerplexityModel(str, Enum):
|
||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
|
||||
|
||||
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||
if isinstance(value, PerplexityModel):
|
||||
return value
|
||||
try:
|
||||
return PerplexityModel(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PerplexityModel '{value}', "
|
||||
f"falling back to {PerplexityModel.SONAR.value}"
|
||||
)
|
||||
return PerplexityModel.SONAR
|
||||
|
||||
|
||||
PerplexityCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||
]
|
||||
@@ -72,6 +88,25 @@ class PerplexityBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||
|
||||
@field_validator("model", mode="before")
|
||||
@classmethod
|
||||
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||
"""Fall back to SONAR if the model value is not a valid
|
||||
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||
generator)."""
|
||||
return _sanitize_perplexity_model(v)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
"""Sanitize the model field before JSON schema validation so that
|
||||
invalid values are replaced with the default instead of raising a
|
||||
BlockInputError."""
|
||||
model_value = data.get("model")
|
||||
if model_value is not None:
|
||||
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||
return super().validate_data(data)
|
||||
|
||||
system_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
@@ -136,7 +171,7 @@ class PerplexityBlock(Block):
|
||||
) -> dict[str, Any]:
|
||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -2232,6 +2232,7 @@ class DeleteRedditPostBlock(Block):
|
||||
("post_id", "abc123"),
|
||||
],
|
||||
test_mock={"delete_post": lambda creds, post_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -2290,6 +2291,7 @@ class DeleteRedditCommentBlock(Block):
|
||||
("comment_id", "xyz789"),
|
||||
],
|
||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -72,6 +72,7 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||
},
|
||||
is_sensitive_action=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -83,7 +83,8 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
@@ -137,7 +138,7 @@ class StagehandObserveBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -227,7 +228,7 @@ class StagehandActBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -324,7 +325,7 @@ class StagehandExtractBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
sort_by: StoreAgentsSortOptions = SchemaField(
|
||||
description="How to sort the results", default=StoreAgentsSortOptions.RATING
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.perplexity import (
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
PerplexityBlock,
|
||||
PerplexityModel,
|
||||
)
|
||||
|
||||
|
||||
def _make_input(**overrides) -> dict:
|
||||
defaults = {
|
||||
"prompt": "test query",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return defaults
|
||||
|
||||
|
||||
class TestPerplexityModelFallback:
|
||||
"""Tests for fallback_invalid_model field_validator."""
|
||||
|
||||
def test_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
def test_valid_model_string_is_kept(self):
|
||||
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||
assert inp.model == PerplexityModel.SONAR_PRO
|
||||
|
||||
def test_valid_enum_value_is_kept(self):
|
||||
inp = PerplexityBlock.Input(
|
||||
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||
)
|
||||
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||
|
||||
def test_default_model_when_omitted(self):
|
||||
inp = PerplexityBlock.Input(**_make_input())
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_value",
|
||||
[
|
||||
"perplexity/sonar",
|
||||
"perplexity/sonar-pro",
|
||||
"perplexity/sonar-deep-research",
|
||||
],
|
||||
)
|
||||
def test_all_valid_models_accepted(self, model_value: str):
|
||||
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||
assert inp.model.value == model_value
|
||||
|
||||
|
||||
class TestPerplexityValidateData:
|
||||
"""Tests for validate_data which runs during block execution (before
|
||||
Pydantic instantiation). Invalid models must be sanitized here so
|
||||
JSON schema validation does not reject them."""
|
||||
|
||||
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == PerplexityModel.SONAR.value
|
||||
|
||||
def test_valid_model_unchanged_by_validate_data(self):
|
||||
data = _make_input(model="perplexity/sonar-pro")
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
assert data["model"] == "perplexity/sonar-pro"
|
||||
|
||||
def test_missing_model_uses_default(self):
|
||||
data = _make_input() # no model key
|
||||
error = PerplexityBlock.Input.validate_data(data)
|
||||
assert error is None
|
||||
inp = PerplexityBlock.Input(**data)
|
||||
assert inp.model == PerplexityModel.SONAR
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
@@ -121,7 +122,10 @@ async def test_search_store_agents_block(mocker):
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
query="test",
|
||||
category="productivity",
|
||||
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
|
||||
limit=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
|
||||
@@ -13,6 +13,7 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -21,6 +22,7 @@ from backend.copilot.model import (
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -38,7 +40,7 @@ from backend.copilot.response_model import (
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
_get_openai_client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
@@ -61,8 +63,8 @@ async def _update_title_async(
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
@@ -87,7 +89,7 @@ async def _compress_session_messages(
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
client=_get_openai_client(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
@@ -175,14 +177,17 @@ async def stream_chat_completion_baseline(
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
@@ -198,6 +203,20 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Propagate user/session context to Langfuse so all LLM calls within
|
||||
# this request are grouped under a single trace with proper attribution.
|
||||
_trace_ctx: Any = None
|
||||
try:
|
||||
_trace_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-baseline",
|
||||
tags=["baseline"],
|
||||
)
|
||||
_trace_ctx.__enter__()
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
@@ -216,7 +235,7 @@ async def stream_chat_completion_baseline(
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
@@ -272,7 +291,7 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
@@ -282,7 +301,7 @@ async def stream_chat_completion_baseline(
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
@@ -385,6 +404,13 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default="https://openrouter.ai/api/v1",
|
||||
default=OPENROUTER_BASE_URL,
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
@@ -62,6 +65,10 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
@@ -87,6 +94,10 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -104,9 +115,37 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=43200, # 12 hours — same as session_ttl
|
||||
description="E2B sandbox keepalive timeout in seconds.",
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
)
|
||||
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
|
||||
default="pause",
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
|
||||
Single source of truth for "should we use E2B right now?".
|
||||
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
|
||||
separately at call sites.
|
||||
"""
|
||||
return self.use_e2b_sandbox and bool(self.e2b_api_key)
|
||||
|
||||
@property
|
||||
def active_e2b_api_key(self) -> str | None:
|
||||
"""Return the E2B API key when E2B is enabled and configured, else None.
|
||||
|
||||
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
|
||||
Use in callers::
|
||||
|
||||
if api_key := config.active_e2b_api_key:
|
||||
# E2B is active; api_key is narrowed to str
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
@@ -121,7 +160,7 @@ class ChatConfig(BaseSettings):
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@@ -129,7 +168,7 @@ class ChatConfig(BaseSettings):
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
@@ -139,13 +178,16 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
|
||||
# The SDK CLI picks it up from the env directly. Including it
|
||||
# would pair it with the OpenRouter base_url, causing auth failures.
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if v is None:
|
||||
if not v:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
@@ -153,7 +195,7 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
v = os.getenv("OPENAI_BASE_URL")
|
||||
if not v:
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@@ -167,6 +209,15 @@ class ChatConfig(BaseSettings):
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for ChatConfig."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
def test_both_enabled_and_key_present_returns_true(self):
|
||||
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is True
|
||||
|
||||
def test_enabled_but_missing_key_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
|
||||
assert cfg.e2b_active is False
|
||||
|
||||
def test_disabled_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is False
|
||||
@@ -6,6 +6,32 @@
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||
# in PendingHumanReview and other tables.
|
||||
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
|
||||
|
||||
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
|
||||
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
|
||||
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
|
||||
|
||||
# Separator used in synthetic node_exec_id to encode node_id.
|
||||
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
|
||||
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
|
||||
|
||||
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
"""Extract node_id from a synthetic node_exec_id.
|
||||
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
128
autogpt_platform/backend/backend/copilot/context.py
Normal file
128
autogpt_platform/backend/backend/copilot/context.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Shared execution context for copilot SDK tool handlers.
|
||||
|
||||
All context variables and their accessors live here so that
|
||||
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
|
||||
without creating circular dependencies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
# validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Return the current (user_id, session) pair for the active request."""
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK working directory for the current session (empty string if unset)."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for context.py — execution context variables and path helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = "test-session"
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context variable getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_execution_context_defaults():
|
||||
"""get_execution_context returns (None, session) when user_id is not set."""
|
||||
set_execution_context(None, _make_session())
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id is None
|
||||
assert session is not None
|
||||
|
||||
|
||||
def test_set_and_get_execution_context():
|
||||
"""set_execution_context stores user_id and session."""
|
||||
mock_session = _make_session()
|
||||
set_execution_context("user-abc", mock_session)
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id == "user-abc"
|
||||
assert session is mock_session
|
||||
|
||||
|
||||
def test_get_current_sandbox_none_by_default():
|
||||
"""get_current_sandbox returns None when no sandbox is set."""
|
||||
set_execution_context("u1", _make_session(), sandbox=None)
|
||||
assert get_current_sandbox() is None
|
||||
|
||||
|
||||
def test_get_current_sandbox_returns_set_value():
|
||||
"""get_current_sandbox returns the sandbox set via set_execution_context."""
|
||||
mock_sandbox = MagicMock()
|
||||
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
|
||||
def test_get_sdk_cwd_returns_set_value():
|
||||
"""get_sdk_cwd returns the value set via set_execution_context."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_allowed_local_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_allowed_local_path_empty():
|
||||
assert not is_allowed_local_path("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_inside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
path = os.path.join(cwd, "file.txt")
|
||||
assert is_allowed_local_path(path, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sdk_cwd_itself():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert is_allowed_local_path(cwd, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_outside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert not is_allowed_local_path("/etc/passwd", cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
"""Without sdk_cwd or project_dir, all paths are rejected."""
|
||||
_current_project_dir.set("")
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_valid():
|
||||
assert (
|
||||
resolve_sandbox_path("/home/user/project/main.py")
|
||||
== "/home/user/project/main.py"
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_relative():
|
||||
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_workdir_itself():
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_normalizes_dots():
|
||||
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
@@ -81,6 +81,35 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
|
||||
@@ -6,6 +6,8 @@ in a thread-local context, following the graph executor pattern.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
@@ -108,8 +110,41 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
|
||||
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
|
||||
self._prewarm_cli()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
|
||||
result.returncode, # type: ignore[reportCallIssue]
|
||||
cli_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
@@ -208,9 +243,10 @@ class CoPilotProcessor:
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = await is_feature_enabled(
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
@@ -228,6 +264,8 @@ class CoPilotProcessor:
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
|
||||
@@ -469,8 +469,16 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -685,24 +693,34 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
"""
|
||||
try:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
@@ -714,9 +732,8 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Scheduler job to generate LLM-optimized block descriptions.
|
||||
|
||||
Runs periodically to rewrite block descriptions into concise, actionable
|
||||
summaries that help the copilot LLM pick the right blocks during agent
|
||||
generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.util.clients import get_database_manager_client, get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a technical writer for an automation platform. "
|
||||
"Rewrite the following block description to be concise (under 50 words), "
|
||||
"informative, and actionable. Focus on what the block does and when to "
|
||||
"use it. Output ONLY the rewritten description, nothing else. "
|
||||
"Do not use markdown formatting."
|
||||
)
|
||||
|
||||
# Rate-limit delay between sequential LLM calls (seconds)
|
||||
_RATE_LIMIT_DELAY = 0.5
|
||||
# Maximum tokens for optimized description generation
|
||||
_MAX_DESCRIPTION_TOKENS = 150
|
||||
# Model for generating optimized descriptions (fast, cheap)
|
||||
_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
|
||||
"""Call the shared OpenAI client to rewrite each block description."""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.error(
|
||||
"No OpenAI client configured, skipping block description optimization"
|
||||
)
|
||||
return {}
|
||||
|
||||
results: dict[str, str] = {}
|
||||
for block in blocks:
|
||||
block_id = block["id"]
|
||||
block_name = block["name"]
|
||||
description = block["description"]
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Block name: {block_name}\nDescription: {description}",
|
||||
},
|
||||
],
|
||||
max_tokens=_MAX_DESCRIPTION_TOKENS,
|
||||
)
|
||||
optimized = (response.choices[0].message.content or "").strip()
|
||||
if optimized:
|
||||
results[block_id] = optimized
|
||||
logger.debug("Optimized description for %s", block_name)
|
||||
else:
|
||||
logger.warning("Empty response for block %s", block_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to optimize description for %s", block_name, exc_info=True
|
||||
)
|
||||
|
||||
await asyncio.sleep(_RATE_LIMIT_DELAY)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def optimize_block_descriptions() -> dict[str, int]:
|
||||
"""Generate optimized descriptions for blocks that don't have one yet.
|
||||
|
||||
Uses the shared OpenAI client to rewrite block descriptions into concise
|
||||
summaries suitable for agent generation prompts.
|
||||
|
||||
Returns:
|
||||
Dict with counts: processed, success, failed, skipped.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
blocks = db_client.get_blocks_needing_optimization()
|
||||
if not blocks:
|
||||
logger.info("All blocks already have optimized descriptions")
|
||||
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
|
||||
|
||||
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
|
||||
|
||||
non_empty = [b for b in blocks if b.get("description", "").strip()]
|
||||
skipped = len(blocks) - len(non_empty)
|
||||
|
||||
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
|
||||
|
||||
stats = {
|
||||
"processed": len(non_empty),
|
||||
"success": len(new_descriptions),
|
||||
"failed": len(non_empty) - len(new_descriptions),
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Block description optimization complete: "
|
||||
"%d/%d succeeded, %d failed, %d skipped",
|
||||
stats["success"],
|
||||
stats["processed"],
|
||||
stats["failed"],
|
||||
stats["skipped"],
|
||||
)
|
||||
|
||||
if new_descriptions:
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
db_client.update_block_optimized_description(block_id, optimized)
|
||||
|
||||
# Update in-memory descriptions first so the cache rebuilds with fresh data.
|
||||
try:
|
||||
block_classes = get_blocks()
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
if block_id in block_classes:
|
||||
block_classes[block_id]._optimized_description = optimized
|
||||
logger.info(
|
||||
"Updated %d in-memory block descriptions", len(new_descriptions)
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not update in-memory block descriptions", exc_info=True
|
||||
)
|
||||
|
||||
from backend.copilot.tools.agent_generator.blocks import (
|
||||
reset_block_caches, # local to avoid circular import
|
||||
)
|
||||
|
||||
reset_block_caches()
|
||||
|
||||
return stats
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Unit tests for optimize_blocks._optimize_descriptions."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
|
||||
|
||||
|
||||
def _make_client_response(text: str) -> MagicMock:
|
||||
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
|
||||
choice = MagicMock()
|
||||
choice.message.content = text
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
return response
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestOptimizeDescriptions:
|
||||
"""Tests for _optimize_descriptions async function."""
|
||||
|
||||
def test_returns_empty_when_no_client(self):
|
||||
with patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
|
||||
):
|
||||
result = _run(
|
||||
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_success_single_block(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("Short desc.")
|
||||
)
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {"b1": "Short desc."}
|
||||
client.chat.completions.create.assert_called_once()
|
||||
|
||||
def test_skips_block_on_exception(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_sleeps_between_blocks(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("desc")
|
||||
)
|
||||
blocks = [
|
||||
{"id": "b1", "name": "B1", "description": "d1"},
|
||||
{"id": "b2", "name": "B2", "description": "d2"},
|
||||
]
|
||||
sleep_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
|
||||
):
|
||||
_run(_optimize_descriptions(blocks))
|
||||
|
||||
assert sleep_mock.call_count == 2
|
||||
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)
|
||||
255
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
255
autogpt_platform/backend/backend/copilot/prompting.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Centralized prompt building logic for CoPilot.
|
||||
|
||||
This module contains all prompt construction functions and constants,
|
||||
handling the distinction between:
|
||||
- SDK mode vs Baseline mode (tool documentation needs)
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
working_dir: str,
|
||||
sandbox_type: str,
|
||||
storage_system_1_name: str,
|
||||
storage_system_1_characteristics: list[str],
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
Template function handles all formatting (bullets, indentation, markdown).
|
||||
Callers provide clean data as lists of strings.
|
||||
|
||||
Args:
|
||||
working_dir: Working directory path
|
||||
sandbox_type: Description of bash_exec sandbox
|
||||
storage_system_1_name: Name of primary storage (ephemeral or cloud)
|
||||
storage_system_1_characteristics: List of characteristic descriptions
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
|
||||
|
||||
return f"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
storage_system_1_name="Ephemeral working directory",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files here are **lost between turns** — do NOT rely on them persisting",
|
||||
"Use for temporary work: running scripts, processing data, etc.",
|
||||
],
|
||||
file_move_name_1_to_2="Ephemeral → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Ephemeral",
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
storage_system_1_name="Cloud sandbox",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by all file tools AND `bash_exec` — same filesystem",
|
||||
"Full Linux environment with internet access",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files **persist across turns** within the current session",
|
||||
"Lost when the session expires (12 h inactivity)",
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
"""Get the supplement for SDK mode (Claude Agent SDK).
|
||||
|
||||
SDK mode does NOT include tool documentation because Claude automatically
|
||||
receives tool schemas from the SDK. Only includes technical notes about
|
||||
storage systems and execution environment.
|
||||
|
||||
Args:
|
||||
use_e2b: Whether E2B cloud sandbox is being used
|
||||
cwd: Current working directory (only used in local_storage mode)
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
@@ -3,12 +3,45 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
## Agent Generation Guide
|
||||
|
||||
You can create, edit, and customize agents directly. You ARE the brain —
|
||||
generate the agent JSON yourself using block schemas, then validate and save.
|
||||
|
||||
### Workflow for Creating/Editing Agents
|
||||
|
||||
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||
- Use block IDs from step 1 as `block_id` in nodes
|
||||
- Wire outputs to inputs using links
|
||||
- Set design-time config in `input_default`
|
||||
- Use `AgentInputBlock` for values the user provides at runtime
|
||||
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
for errors
|
||||
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
or fix manually based on the error descriptions. Iterate until valid.
|
||||
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
the final `agent_json`
|
||||
|
||||
### Agent JSON Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "<UUID v4>", // auto-generated if omitted
|
||||
"version": 1,
|
||||
"is_active": true,
|
||||
"name": "Agent Name",
|
||||
"description": "What the agent does",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"block_id": "<block UUID from find_block>",
|
||||
"input_default": {
|
||||
"field_name": "design-time value"
|
||||
},
|
||||
"metadata": {
|
||||
"position": {"x": 0, "y": 0},
|
||||
"customized_name": "Optional display name"
|
||||
}
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"source_id": "<source node UUID>",
|
||||
"source_name": "output_field_name",
|
||||
"sink_id": "<sink node UUID>",
|
||||
"sink_name": "input_field_name",
|
||||
"is_static": false
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### REQUIRED: AgentInputBlock and AgentOutputBlock
|
||||
|
||||
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
|
||||
These define the agent's interface — what it accepts and what it produces.
|
||||
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
- Required `input_default` fields: `name` (str)
|
||||
- The `value` input should be linked from another block's output
|
||||
- Optional: `title`, `description`, `format` (Jinja2 template)
|
||||
- Create one AgentOutputBlock per distinct result to show the user
|
||||
|
||||
Without these blocks, the agent has no interface and the user cannot provide
|
||||
inputs or see outputs. NEVER skip them.
|
||||
|
||||
### Key Rules
|
||||
|
||||
- **Name & description**: Include `name` and `description` in the agent JSON
|
||||
when creating a new agent, or when editing and the agent's purpose changed.
|
||||
Without these the agent gets a generic default name.
|
||||
- **Design-time vs runtime**: `input_default` = values known at build time.
|
||||
For user-provided values, create an `AgentInputBlock` node and link its
|
||||
output to the consuming block's input.
|
||||
- **Credentials**: Do NOT require credentials upfront. Users configure
|
||||
credentials later in the platform UI after the agent is saved.
|
||||
- **Node spacing**: Position nodes with at least 800 X-units between them.
|
||||
- **Nested properties**: Use `parentField_#_childField` notation in link
|
||||
sink_name/source_name to access nested object fields.
|
||||
- **is_static links**: Set `is_static: true` when the link carries a
|
||||
design-time constant (matches a field in inputSchema with a default).
|
||||
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
|
||||
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
|
||||
literal braces in prompt strings — single `{` and `}` are for
|
||||
template variables.
|
||||
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
|
||||
`graph_version` in input_default, and wire inputs/outputs to match
|
||||
the sub-agent's schema.
|
||||
|
||||
### Using Sub-Agents (AgentExecutorBlock)
|
||||
|
||||
To compose agents using other agents as sub-agents:
|
||||
1. Call `find_library_agent` to find the sub-agent — the response includes
|
||||
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
|
||||
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
|
||||
3. Set `input_default`:
|
||||
- `graph_id`: from the library agent's `graph_id`
|
||||
- `graph_version`: from the library agent's `graph_version`
|
||||
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
|
||||
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
|
||||
- `user_id`: leave as `""` (filled at runtime)
|
||||
- `inputs`: `{}` (populated by links at runtime)
|
||||
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
|
||||
property names (e.g., if input_schema has a `"url"` property, use
|
||||
`"url"` as the sink_name)
|
||||
5. Wire outputs: link from source names matching the sub-agent's
|
||||
`output_schema` property names
|
||||
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
|
||||
the library agent IDs used, so the fixer can validate schemas
|
||||
|
||||
### Using MCP Tools (MCPToolBlock)
|
||||
|
||||
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
1. The user must specify which MCP server URL and tool name they want
|
||||
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||
3. Set `input_default`:
|
||||
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
|
||||
- `selected_tool`: the tool name on that server
|
||||
- `tool_input_schema`: JSON Schema for the tool's inputs
|
||||
- `tool_arguments`: `{}` (populated by links or hardcoded values)
|
||||
4. The block requires MCP credentials — the user configures these in the
|
||||
platform UI after the agent is saved
|
||||
5. Wire inputs using the tool argument field name directly as the sink_name
|
||||
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
|
||||
automatically collects top-level fields matching tool_input_schema into
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
|
||||
input_default: {"name": "user_text", "title": "Text to process"},
|
||||
output: "result")
|
||||
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
|
||||
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
|
||||
input_default: {"name": "summary", "title": "Summary"},
|
||||
input: "value" linked from Node 2's output)
|
||||
@@ -11,7 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -27,6 +27,19 @@ from ..response_model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompactionResult:
|
||||
"""Result of emit_end_if_ready — bundles events with compaction metadata.
|
||||
|
||||
Eliminates the need for separate ``compaction_just_ended`` checks,
|
||||
preventing TOCTOU races between the emit call and the flag read.
|
||||
"""
|
||||
|
||||
events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
just_ended: bool = False
|
||||
transcript_path: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -177,11 +190,22 @@ class CompactionTracker:
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path: str = ""
|
||||
|
||||
@property
|
||||
def on_compact(self) -> Callable[[], None]:
|
||||
"""Callback for the PreCompact hook."""
|
||||
return self._compact_start.set
|
||||
def on_compact(self, transcript_path: str = "") -> None:
|
||||
"""Callback for the PreCompact hook. Stores transcript_path."""
|
||||
if (
|
||||
self._transcript_path
|
||||
and transcript_path
|
||||
and self._transcript_path != transcript_path
|
||||
):
|
||||
logger.warning(
|
||||
"[Compaction] Overwriting transcript_path %s -> %s",
|
||||
self._transcript_path,
|
||||
transcript_path,
|
||||
)
|
||||
self._transcript_path = transcript_path
|
||||
self._compact_start.set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
@@ -201,6 +225,7 @@ class CompactionTracker:
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
self._transcript_path = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
@@ -211,15 +236,20 @@ class CompactionTracker:
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""If compaction is in progress, emit end events and persist."""
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
|
||||
"""If compaction is in progress, emit end events and persist.
|
||||
|
||||
Returns a ``CompactionResult`` with ``just_ended=True`` and the
|
||||
captured ``transcript_path`` when a compaction cycle completes.
|
||||
This avoids a separate flag check (TOCTOU-safe).
|
||||
"""
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return []
|
||||
return CompactionResult()
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return []
|
||||
return CompactionResult()
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
@@ -232,8 +262,12 @@ class CompactionTracker:
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
|
||||
transcript_path = self._transcript_path
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
self._transcript_path = ""
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return done_events
|
||||
return CompactionResult(
|
||||
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||
)
|
||||
|
||||
@@ -195,10 +195,11 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 2
|
||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[1], StreamFinishStep)
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
# Should persist
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@@ -210,28 +211,32 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
# Don't call emit_start_if_ready
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 5 # Full self-contained event
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert len(result.events) == 5 # Full self-contained event
|
||||
assert isinstance(result.events[0], StreamStartStep)
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_done(self):
|
||||
async def test_emit_end_no_op_when_no_new_compaction(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# Second call should be no-op
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
# Second call should be no-op (no new on_compact)
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
assert result2.events == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is False
|
||||
assert result.events == []
|
||||
|
||||
def test_emit_pre_query(self):
|
||||
tracker = CompactionTracker()
|
||||
@@ -246,20 +251,29 @@ class TestCompactionTracker:
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker._transcript_path = "/some/path"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_blocks_sdk_compaction(self):
|
||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
||||
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
|
||||
"""After pre-query compaction, SDK compaction is blocked until
|
||||
reset_for_query is called."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
# _done is True so emit_start_if_ready is blocked
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == [] # _done blocks it
|
||||
assert evts == []
|
||||
# Reset clears _done, allowing subsequent compaction
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
@@ -279,9 +293,9 @@ class TestCompactionTracker:
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
end_evts = await tracker.emit_end_if_ready(session)
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
start_evt = start_evts[1]
|
||||
end_evt = end_evts[0]
|
||||
end_evt = result.events[0]
|
||||
assert isinstance(start_evt, StreamToolInputStart)
|
||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||
assert start_evt.toolCallId == end_evt.toolCallId
|
||||
@@ -289,3 +303,105 @@ class TestCompactionTracker:
|
||||
tool_calls = session.messages[0].tool_calls
|
||||
assert tool_calls is not None
|
||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_within_query(self):
|
||||
"""Two mid-stream compactions within a single query both trigger."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction cycle
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert len(result1.events) == 2
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Second compaction cycle (should NOT be blocked — _done resets
|
||||
# because emit_end_if_ready sets it True, but the next on_compact
|
||||
# + emit_start_if_ready checks !_done which IS True now.
|
||||
# So we need reset_for_query between queries, but within a single
|
||||
# query multiple compactions work because _done blocks emit_start
|
||||
# until the next message arrives, at which point emit_end detects it)
|
||||
#
|
||||
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||
# the stream loop to reset. In practice service.py doesn't call
|
||||
# reset between compactions within the same query — let's verify
|
||||
# the actual behavior.
|
||||
tracker.on_compact("/path/2")
|
||||
# _done is True from first compaction, so start is blocked
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert start_evts == []
|
||||
# But emit_end returns no-op because _done is True
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_with_intervening_message(self):
|
||||
"""Multiple compactions work when the stream loop processes messages between them.
|
||||
|
||||
In the real service.py flow:
|
||||
1. PreCompact fires → on_compact()
|
||||
2. emit_start shows spinner
|
||||
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||
4. Stream continues processing messages...
|
||||
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||
7. The stream loop needs to detect this and handle accordingly
|
||||
|
||||
The actual flow for multiple compactions within a query requires
|
||||
_done to be cleared between them. The service.py code uses
|
||||
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||
stays True until reset_for_query.
|
||||
"""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
# First compaction
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Simulate reset between queries
|
||||
tracker.reset_for_query()
|
||||
|
||||
# Second compaction in new query
|
||||
tracker.on_compact("/path/2")
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert len(start_evts) == 3
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
def test_on_compact_stores_transcript_path(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact("/some/path.jsonl")
|
||||
assert tracker._transcript_path == "/some/path.jsonl"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_returns_transcript_path(self):
|
||||
"""CompactionResult includes the transcript_path from on_compact."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/my/session.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == "/my/session.jsonl"
|
||||
# transcript_path is cleared after emit_end
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_clears_transcript_path(self):
|
||||
"""After emit_end, _transcript_path is reset so it doesn't leak to
|
||||
subsequent non-compaction emit_end calls."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/first/path.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# After compaction, _transcript_path is cleared
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
@@ -26,6 +27,7 @@ async def stream_chat_completion_dummy(
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -17,36 +15,23 @@ import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Lazy imports to break circular dependency with tool_adapter.
|
||||
|
||||
|
||||
def _get_sandbox(): # type: ignore[return]
|
||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
||||
|
||||
return is_allowed_local_path(path)
|
||||
|
||||
|
||||
def _resolve_remote(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
return is_allowed_local_path(path, get_sdk_cwd())
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = _resolve_remote(file_path)
|
||||
remote = resolve_sandbox_path(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Write content to a sandbox file, creating parent directories as needed."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Replace a substring in a sandbox file, with optional replace-all support."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Find files matching a name pattern inside the sandbox using ``find``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded) as fh:
|
||||
with open(expanded, encoding="utf-8", errors="replace") as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
@@ -280,7 +270,9 @@ E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically.",
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -7,59 +7,60 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .e2b_file_tools import _read_local, _resolve_remote
|
||||
from .tool_adapter import _current_project_dir
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveRemote:
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert _resolve_remote("/home/user") == "/home/user"
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("../../etc/passwd")
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/user/../../etc/passwd")
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/etc/passwd")
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/")
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
_resolve_remote("/home/other/file.txt")
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert _resolve_remote("src/") == "/home/user/src"
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,531 @@
|
||||
"""End-to-end compaction flow test.
|
||||
|
||||
Simulates the full service.py compaction lifecycle using real-format
|
||||
JSONL session files — no SDK subprocess needed. Exercises:
|
||||
|
||||
1. TranscriptBuilder loads a "downloaded" transcript
|
||||
2. User query appended, assistant response streamed
|
||||
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||
4. Next message → emit_start_if_ready() yields spinner events
|
||||
5. Message after that → emit_end_if_ready() returns CompactionResult
|
||||
6. read_compacted_entries() reads the CLI session file
|
||||
7. TranscriptBuilder.replace_entries() syncs state
|
||||
8. More messages appended post-compaction
|
||||
9. to_jsonl() exports full state for upload
|
||||
10. Fresh builder loads the export — roundtrip verified
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import CompactionTracker
|
||||
from backend.copilot.sdk.transcript import (
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
)
|
||||
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic CLI session file content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Pre-compaction conversation
|
||||
USER_1 = {
|
||||
"type": "user",
|
||||
"uuid": "u1",
|
||||
"message": {"role": "user", "content": "What files are in this project?"},
|
||||
}
|
||||
ASST_1_THINKING = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-think",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
ASST_1_TOOL = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-tool",
|
||||
"parentUuid": "u1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_aaa",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu1",
|
||||
"name": "Bash",
|
||||
"input": {"command": "ls"},
|
||||
}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
TOOL_RESULT_1 = {
|
||||
"type": "user",
|
||||
"uuid": "tr1",
|
||||
"parentUuid": "a1-tool",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tu1",
|
||||
"content": "file1.py\nfile2.py",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
ASST_1_TEXT = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1-text",
|
||||
"parentUuid": "tr1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_bbb",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
# Progress entries (should be stripped during upload)
|
||||
PROGRESS_1 = {
|
||||
"type": "progress",
|
||||
"uuid": "prog1",
|
||||
"parentUuid": "a1-tool",
|
||||
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||
}
|
||||
# Second user message
|
||||
USER_2 = {
|
||||
"type": "user",
|
||||
"uuid": "u2",
|
||||
"parentUuid": "a1-text",
|
||||
"message": {"role": "user", "content": "Show me file1.py"},
|
||||
}
|
||||
ASST_2 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a2",
|
||||
"parentUuid": "u2",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ccc",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Compaction summary (written by CLI after context compaction) ---
|
||||
COMPACT_SUMMARY = {
|
||||
"type": "summary",
|
||||
"uuid": "cs1",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||
"User then asked to see file1.py."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction assistant response
|
||||
POST_COMPACT_ASST = {
|
||||
"type": "assistant",
|
||||
"uuid": "a3",
|
||||
"parentUuid": "cs1",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_ddd",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Post-compaction user follow-up
|
||||
USER_3 = {
|
||||
"type": "user",
|
||||
"uuid": "u3",
|
||||
"parentUuid": "a3",
|
||||
"message": {"role": "user", "content": "Now show file2.py"},
|
||||
}
|
||||
ASST_3 = {
|
||||
"type": "assistant",
|
||||
"uuid": "a4",
|
||||
"parentUuid": "u3",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_sdk_eee",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionE2E:
|
||||
def _write_session_file(self, session_dir, entries):
|
||||
"""Write a CLI session JSONL file."""
|
||||
path = session_dir / "session.jsonl"
|
||||
path.write_text(_make_jsonl(*entries))
|
||||
return path
|
||||
|
||||
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
|
||||
"""Simulate the complete service.py compaction flow.
|
||||
|
||||
Timeline:
|
||||
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||
2. Current turn: download → load_previous
|
||||
3. User sends "Now show file2.py" → append_user
|
||||
4. SDK starts streaming response
|
||||
5. Mid-stream: PreCompact hook fires (context too large)
|
||||
6. CLI writes compaction summary to session file
|
||||
7. Next SDK message → emit_start (spinner)
|
||||
8. Following message → emit_end (CompactionResult)
|
||||
9. read_compacted_entries reads the session file
|
||||
10. replace_entries syncs TranscriptBuilder
|
||||
11. More assistant messages appended
|
||||
12. Export → upload → next turn downloads it
|
||||
"""
|
||||
# --- Setup CLI projects directory ---
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||
previous_transcript = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(previous_transcript)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
# --- Step 3: User sends new query ---
|
||||
builder.append_user("Now show file2.py")
|
||||
assert builder.entry_count == 8
|
||||
|
||||
# --- Step 4: SDK starts streaming ---
|
||||
builder.append_assistant(
|
||||
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
assert builder.entry_count == 9
|
||||
|
||||
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
USER_3,
|
||||
ASST_3,
|
||||
],
|
||||
)
|
||||
|
||||
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test-user")
|
||||
tracker.on_compact(str(session_file))
|
||||
|
||||
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||
start_events = tracker.emit_start_if_ready()
|
||||
assert len(start_events) == 3
|
||||
assert isinstance(start_events[0], StreamStartStep)
|
||||
assert isinstance(start_events[1], StreamToolInputStart)
|
||||
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||
|
||||
# Verify tool_call_id is set
|
||||
tool_call_id = start_events[1].toolCallId
|
||||
assert tool_call_id.startswith("compaction-")
|
||||
|
||||
# --- Step 9: Following message → emit_end ---
|
||||
result = _run(tracker.emit_end_if_ready(session))
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == str(session_file)
|
||||
assert len(result.events) == 2
|
||||
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||
assert isinstance(result.events[1], StreamFinishStep)
|
||||
# Verify same tool_call_id
|
||||
assert result.events[0].toolCallId == tool_call_id
|
||||
|
||||
# Session should have compaction messages persisted
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[1].role == "tool"
|
||||
|
||||
# --- Step 10: read_compacted_entries + replace_entries ---
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||
assert len(compacted) == 4
|
||||
assert compacted[0]["uuid"] == "cs1"
|
||||
assert compacted[0]["isCompactSummary"] is True
|
||||
|
||||
# Replace builder state with compacted entries
|
||||
old_count = builder.entry_count
|
||||
builder.replace_entries(compacted)
|
||||
assert builder.entry_count == 4 # Only compacted entries
|
||||
assert builder.entry_count < old_count # Compaction reduced entries
|
||||
|
||||
# --- Step 11: More assistant messages after compaction ---
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
assert builder.entry_count == 5
|
||||
|
||||
# --- Step 12: Export for upload ---
|
||||
output = builder.to_jsonl()
|
||||
assert output # Not empty
|
||||
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert len(output_entries) == 5
|
||||
|
||||
# Verify structure:
|
||||
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||
assert output_entries[0]["type"] == "summary"
|
||||
assert output_entries[0].get("isCompactSummary") is True
|
||||
assert output_entries[0]["uuid"] == "cs1"
|
||||
assert output_entries[1]["uuid"] == "a3"
|
||||
assert output_entries[2]["uuid"] == "u3"
|
||||
assert output_entries[3]["uuid"] == "a4"
|
||||
assert output_entries[4]["type"] == "assistant"
|
||||
|
||||
# Verify parent chain is intact
|
||||
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||
|
||||
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||
builder2 = TranscriptBuilder()
|
||||
builder2.load_previous(output)
|
||||
assert builder2.entry_count == 5
|
||||
|
||||
# isCompactSummary survives roundtrip
|
||||
output2 = builder2.to_jsonl()
|
||||
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||
assert first_entry.get("isCompactSummary") is True
|
||||
|
||||
# Can append more messages
|
||||
builder2.append_user("What about file3.py?")
|
||||
assert builder2.entry_count == 6
|
||||
final_output = builder2.to_jsonl()
|
||||
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||
assert last_entry["type"] == "user"
|
||||
# Parented to the last entry from previous turn
|
||||
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||
|
||||
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
|
||||
"""Two compactions in the same session (across reset_for_query)."""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
tracker = CompactionTracker()
|
||||
session = ChatSession.new(user_id="test")
|
||||
builder = TranscriptBuilder()
|
||||
|
||||
# --- First query with compaction ---
|
||||
builder.append_user("first question")
|
||||
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||
|
||||
# Write session file for first compaction
|
||||
first_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-first",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "First compaction summary"},
|
||||
}
|
||||
first_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-first",
|
||||
"parentUuid": "cs-first",
|
||||
"message": {"role": "assistant", "content": "first post-compact"},
|
||||
}
|
||||
file1 = session_dir / "session1.jsonl"
|
||||
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||
|
||||
tracker.on_compact(str(file1))
|
||||
tracker.emit_start_if_ready()
|
||||
result1 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result1.just_ended is True
|
||||
|
||||
compacted1 = read_compacted_entries(str(file1))
|
||||
assert compacted1 is not None
|
||||
builder.replace_entries(compacted1)
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# --- Reset for second query ---
|
||||
tracker.reset_for_query()
|
||||
|
||||
# --- Second query with compaction ---
|
||||
builder.append_user("second question")
|
||||
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||
|
||||
second_summary = {
|
||||
"type": "summary",
|
||||
"uuid": "cs-second",
|
||||
"isCompactSummary": True,
|
||||
"message": {"role": "user", "content": "Second compaction summary"},
|
||||
}
|
||||
second_post = {
|
||||
"type": "assistant",
|
||||
"uuid": "a-second",
|
||||
"parentUuid": "cs-second",
|
||||
"message": {"role": "assistant", "content": "second post-compact"},
|
||||
}
|
||||
file2 = session_dir / "session2.jsonl"
|
||||
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||
|
||||
tracker.on_compact(str(file2))
|
||||
tracker.emit_start_if_ready()
|
||||
result2 = _run(tracker.emit_end_if_ready(session))
|
||||
assert result2.just_ended is True
|
||||
|
||||
compacted2 = read_compacted_entries(str(file2))
|
||||
assert compacted2 is not None
|
||||
builder.replace_entries(compacted2)
|
||||
assert builder.entry_count == 2 # Only second compaction entries
|
||||
|
||||
# Export and verify
|
||||
output = builder.to_jsonl()
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
assert entries[0]["uuid"] == "cs-second"
|
||||
assert entries[0].get("isCompactSummary") is True
|
||||
|
||||
def test_strip_progress_then_load_then_compact_roundtrip(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||
|
||||
This tests the exact sequence that happens across two turns:
|
||||
Turn 1: SDK produces transcript with progress entries
|
||||
Upload: strip_progress_entries removes progress, upload to cloud
|
||||
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||
"""
|
||||
config_dir = tmp_path / "config"
|
||||
projects_dir = config_dir / "projects"
|
||||
session_dir = projects_dir / "proj"
|
||||
session_dir.mkdir(parents=True)
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# --- Turn 1: SDK produces raw transcript ---
|
||||
raw_content = _make_jsonl(
|
||||
USER_1,
|
||||
ASST_1_THINKING,
|
||||
ASST_1_TOOL,
|
||||
PROGRESS_1,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
)
|
||||
|
||||
# Strip progress for upload
|
||||
stripped = strip_progress_entries(raw_content)
|
||||
stripped_entries = [
|
||||
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||
]
|
||||
# Progress should be gone
|
||||
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||
|
||||
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||
builder = TranscriptBuilder()
|
||||
builder.load_previous(stripped)
|
||||
assert builder.entry_count == 7
|
||||
|
||||
builder.append_user("Now show file2.py")
|
||||
builder.append_assistant(
|
||||
[{"type": "text", "text": "Reading file2.py..."}],
|
||||
model="claude-sonnet-4-20250514",
|
||||
)
|
||||
|
||||
# CLI writes session file with compaction
|
||||
session_file = self._write_session_file(
|
||||
session_dir,
|
||||
[
|
||||
USER_1,
|
||||
ASST_1_TOOL,
|
||||
TOOL_RESULT_1,
|
||||
ASST_1_TEXT,
|
||||
USER_2,
|
||||
ASST_2,
|
||||
COMPACT_SUMMARY,
|
||||
POST_COMPACT_ASST,
|
||||
],
|
||||
)
|
||||
|
||||
compacted = read_compacted_entries(str(session_file))
|
||||
assert compacted is not None
|
||||
builder.replace_entries(compacted)
|
||||
|
||||
# Append post-compaction message
|
||||
builder.append_user("Thanks!")
|
||||
output = builder.to_jsonl()
|
||||
|
||||
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||
builder3 = TranscriptBuilder()
|
||||
builder3.load_previous(output)
|
||||
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||
assert builder3.entry_count == 3
|
||||
|
||||
# Compact summary survived the full pipeline
|
||||
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||
assert first.get("isCompactSummary") is True
|
||||
assert first["type"] == "summary"
|
||||
715
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
715
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""File reference protocol for tool call inputs.
|
||||
|
||||
Allows the LLM to pass a file reference instead of embedding large content
|
||||
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
|
||||
arguments before the tool is executed.
|
||||
|
||||
Protocol
|
||||
--------
|
||||
|
||||
@@agptfile:<uri>[<start>-<end>]
|
||||
|
||||
``<uri>`` (required)
|
||||
- ``workspace://<file_id>`` — workspace file by ID
|
||||
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
|
||||
- ``workspace:///<path>`` — workspace file by virtual path
|
||||
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
|
||||
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
|
||||
- Any absolute path that resolves inside the E2B sandbox
|
||||
(``/home/user/...``) when a sandbox is active
|
||||
|
||||
``[<start>-<end>]`` (optional)
|
||||
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
|
||||
Omit to read the entire file.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.sh
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
|
||||
|
||||
Separating this from inline substitution lets callers (e.g. the MCP tool
|
||||
wrapper) block tool execution and surface a helpful error to the model
|
||||
rather than passing an ``[file-ref error: …]`` string as actual input.
|
||||
"""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE_REF_PREFIX = "@@agptfile:"
|
||||
|
||||
# Matches: @@agptfile:<uri>[start-end]?
|
||||
# Group 1 – URI; must start with '/' (absolute path) or 'workspace://'
|
||||
# Group 2 – start line (optional)
|
||||
# Group 3 – end line (optional)
|
||||
_FILE_REF_RE = re.compile(
|
||||
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
|
||||
)
|
||||
|
||||
# Maximum characters returned for a single file reference expansion.
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileRef:
|
||||
uri: str
|
||||
start_line: int | None # 1-indexed, inclusive
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
|
||||
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
|
||||
expand references embedded in larger strings.
|
||||
"""
|
||||
m = _FILE_REF_RE.fullmatch(text.strip())
|
||||
if not m:
|
||||
return None
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if start is not None and start < 1:
|
||||
return None
|
||||
if end is not None and end < 1:
|
||||
return None
|
||||
if start is not None and end is not None and end < start:
|
||||
return None
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> bytes:
|
||||
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
|
||||
|
||||
Raises :class:`ValueError` if the URI cannot be resolved.
|
||||
"""
|
||||
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
|
||||
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
|
||||
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
try:
|
||||
remote = resolve_sandbox_path(plain)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
)
|
||||
|
||||
|
||||
async def resolve_file_ref(
|
||||
ref: FileRef,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
|
||||
|
||||
Non-reference text is passed through unchanged.
|
||||
|
||||
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
|
||||
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
|
||||
where partial expansion is acceptable.
|
||||
|
||||
If *raise_on_error* is ``True``, any resolution failure raises
|
||||
:class:`FileRefExpansionError` immediately so the caller can block the
|
||||
operation and surface a clean error to the model.
|
||||
"""
|
||||
if FILE_REF_PREFIX not in text:
|
||||
return text
|
||||
|
||||
result: list[str] = []
|
||||
last_end = 0
|
||||
total_chars = 0
|
||||
for m in _FILE_REF_RE.finditer(text):
|
||||
result.append(text[last_end : m.start()])
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if (start is not None and start < 1) or (end is not None and end < 1):
|
||||
msg = f"line numbers must be >= 1: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
if start is not None and end is not None and end < start:
|
||||
msg = f"end line must be >= start line: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
try:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
content = content[:remaining] + "\n... [total budget exhausted]"
|
||||
total_chars += len(content)
|
||||
result.append(content)
|
||||
except ValueError as exc:
|
||||
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
result.append(f"[file-ref error: {exc}]")
|
||||
last_end = m.end()
|
||||
|
||||
result.append(text[last_end:])
|
||||
return "".join(result)
|
||||
|
||||
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
response that lets the model correct the reference before retrying.
|
||||
"""
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
"""Integration tests for @@agptfile: reference expansion in tool calls.
|
||||
|
||||
These tests verify the end-to-end behaviour of the file reference protocol:
|
||||
- Parsing @@agptfile: tokens from tool arguments
|
||||
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
|
||||
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
|
||||
- The extended Read tool handler (workspace:// pass-through via session context)
|
||||
|
||||
No real LLM or database is required; workspace reads are stubbed where needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRef,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
read_file_bytes,
|
||||
resolve_file_ref,
|
||||
)
|
||||
from backend.copilot.sdk.tool_adapter import _read_file_handler
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "integ-sess") -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = session_id
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local-file resolution (sdk_cwd)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path():
|
||||
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
# Write a test file inside sdk_cwd
|
||||
test_file = os.path.join(sdk_cwd, "hello.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=None, end_line=None)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path_with_line_range():
|
||||
"""resolve_file_ref respects line ranges for local files."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "multi.txt")
|
||||
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=3, end_line=5)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_sandbox_var.get.return_value = None
|
||||
|
||||
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await resolve_file_ref(ref, user_id="u1", session=_make_session())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string — integration with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_with_real_file():
|
||||
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "data.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("hello world\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"Content: @@agptfile:{test_file}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result == "Content: hello world\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_missing_file_is_surfaced_inline():
|
||||
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"@@agptfile:{missing}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "[file-ref error:" in result
|
||||
assert "not found" in result.lower() or "not allowed" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — dict traversal with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
"""Nested @@agptfile: references in args are fully expanded."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
file_a = os.path.join(sdk_cwd, "a.txt")
|
||||
file_b = os.path.join(sdk_cwd, "b.txt")
|
||||
with open(file_a, "w") as f:
|
||||
f.write("AAA")
|
||||
with open(file_b, "w") as f:
|
||||
f.write("BBB")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"outer": {
|
||||
"content_a": f"@@agptfile:{file_a}",
|
||||
"content_b": f"start @@agptfile:{file_b} end",
|
||||
},
|
||||
"count": 42,
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["outer"]["content_a"] == "AAA"
|
||||
assert result["outer"]["content_b"] == "start BBB end"
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_local_file():
|
||||
"""_read_file_handler reads a local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "read_test.txt")
|
||||
lines = [f"L{i}\n" for i in range(1, 6)]
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj_var, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_proj_var.get.return_value = ""
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": test_file, "offset": 0, "limit": 5}
|
||||
)
|
||||
|
||||
assert not result["isError"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "L1" in text
|
||||
assert "L5" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri():
|
||||
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
|
||||
mock_session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||
)
|
||||
|
||||
assert not result["isError"], result["content"][0]["text"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "workspace file content" in text
|
||||
assert "line two" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri_no_session():
|
||||
"""_read_file_handler returns error when workspace:// is used without session."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=(None, None),
|
||||
):
|
||||
result = await _read_file_handler({"file_path": "workspace://some-id"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_access_denied():
|
||||
"""_read_file_handler rejects paths outside allowed locations."""
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||
mock_sandbox.get.return_value = None
|
||||
|
||||
result = await _read_file_handler({"file_path": "/etc/passwd"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_workspace_virtual_path():
|
||||
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
|
||||
session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
assert result == b"virtual path content"
|
||||
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_sandbox_branch():
|
||||
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
result = await read_file_bytes("/home/user/script.sh", None, session)
|
||||
|
||||
assert result == b"sandbox content"
|
||||
mock_sandbox.files.read.assert_awaited_once_with(
|
||||
"/home/user/script.sh", format="bytes"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await read_file_bytes("/etc/passwd", None, session)
|
||||
1979
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
1979
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user