mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
56 Commits
swiftyos/i
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8892bcd230 | ||
|
|
48ff8300a4 | ||
|
|
c268fc6464 | ||
|
|
aff3fb44af | ||
|
|
9a41312769 | ||
|
|
048fb06b0a | ||
|
|
3f653e6614 | ||
|
|
c9c3d54b2b | ||
|
|
53d58e21d3 | ||
|
|
fa04fb41d8 | ||
|
|
d9c16ded65 | ||
|
|
6dc8429ae7 | ||
|
|
cfe22e5a8f | ||
|
|
0b594a219c | ||
|
|
a8259ca935 | ||
|
|
1f1288d623 | ||
|
|
02645732b8 | ||
|
|
ba301a3912 | ||
|
|
0cd9c0d87a | ||
|
|
a083493aa2 | ||
|
|
c51dc7ad99 | ||
|
|
bc6b82218a | ||
|
|
83e49f71cd | ||
|
|
ef446e4fe9 | ||
|
|
7b1e8ed786 | ||
|
|
7ccfff1040 | ||
|
|
81c7685a82 | ||
|
|
3595c6e769 | ||
|
|
1c2953d61b | ||
|
|
755bc84b1a | ||
|
|
ade2baa58f | ||
|
|
4d35534a89 | ||
|
|
2cc748f34c | ||
|
|
c2e79fa5e1 | ||
|
|
89a5b3178a | ||
|
|
c62d9a24ff | ||
|
|
0e0bfaac29 | ||
|
|
0633475915 | ||
|
|
34a2f9a0a2 | ||
|
|
9f4caa7dfc | ||
|
|
0876d22e22 | ||
|
|
15e3980d65 | ||
|
|
fe9eb2564b | ||
|
|
5641cdd3ca | ||
|
|
bfb843a56e | ||
|
|
684845d946 | ||
|
|
6a6b23c2e1 | ||
|
|
d0a1d72e8a | ||
|
|
f1945d6a2f | ||
|
|
6491cb1e23 | ||
|
|
c7124a5240 | ||
|
|
5537cb2858 | ||
|
|
aef5f6d666 | ||
|
|
8063391d0a | ||
|
|
0bbb12d688 | ||
|
|
19d775c435 |
79
.claude/skills/pr-address/SKILL.md
Normal file
79
.claude/skills/pr-address/SKILL.md
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
---
|
||||||
|
name: pr-address
|
||||||
|
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "1.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# PR Address
|
||||||
|
|
||||||
|
## Find the PR
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||||
|
gh pr view {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetch comments (all sources)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||||
|
```
|
||||||
|
|
||||||
|
**Bots to watch for:**
|
||||||
|
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||||
|
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||||
|
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||||
|
|
||||||
|
## For each unaddressed comment
|
||||||
|
|
||||||
|
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||||
|
|
||||||
|
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||||
|
2. Commit and push the fix
|
||||||
|
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||||
|
|
||||||
|
| Comment type | How to reply |
|
||||||
|
|---|---|
|
||||||
|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||||
|
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||||
|
|
||||||
|
## Format and commit
|
||||||
|
|
||||||
|
After fixing, format the changed code:
|
||||||
|
|
||||||
|
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||||
|
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||||
|
|
||||||
|
If API routes changed, regenerate the frontend client:
|
||||||
|
```bash
|
||||||
|
cd autogpt_platform/backend && poetry run rest &
|
||||||
|
REST_PID=$!
|
||||||
|
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||||
|
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||||
|
cd ../frontend && pnpm generate:api:force
|
||||||
|
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||||
|
```
|
||||||
|
Never manually edit files in `src/app/api/__generated__/`.
|
||||||
|
|
||||||
|
Then commit and **push immediately** — never batch commits without pushing.
|
||||||
|
|
||||||
|
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||||
|
|
||||||
|
## The loop
|
||||||
|
|
||||||
|
```text
|
||||||
|
address comments → format → commit → push
|
||||||
|
→ re-check comments → fix new ones → push
|
||||||
|
→ wait for CI → re-check comments after CI settles
|
||||||
|
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||||
|
```
|
||||||
|
|
||||||
|
While CI runs, stay productive: run local tests, address remaining comments.
|
||||||
|
|
||||||
|
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||||
74
.claude/skills/pr-review/SKILL.md
Normal file
74
.claude/skills/pr-review/SKILL.md
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
---
|
||||||
|
name: pr-review
|
||||||
|
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "1.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# PR Review
|
||||||
|
|
||||||
|
## Find the PR
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||||
|
gh pr view {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Read the diff
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr diff {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetch existing review comments
|
||||||
|
|
||||||
|
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||||
|
```
|
||||||
|
|
||||||
|
## What to check
|
||||||
|
|
||||||
|
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||||
|
|
||||||
|
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||||
|
|
||||||
|
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||||
|
|
||||||
|
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||||
|
|
||||||
|
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||||
|
|
||||||
|
## Output format
|
||||||
|
|
||||||
|
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||||
|
|
||||||
|
| Tier | Badge | Meaning |
|
||||||
|
|---|---|---|
|
||||||
|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||||
|
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||||
|
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||||
|
| Nit | `🔵 **Nit**` | Style / wording |
|
||||||
|
|
||||||
|
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||||
|
|
||||||
|
## Post inline comments
|
||||||
|
|
||||||
|
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Get the latest commit SHA for the PR
|
||||||
|
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||||
|
|
||||||
|
# Post an inline comment on a specific file/line
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||||
|
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||||
|
-f commit_id="$COMMIT_SHA" \
|
||||||
|
-f path="<file path>" \
|
||||||
|
-F line=<line number>
|
||||||
|
```
|
||||||
85
.claude/skills/worktree/SKILL.md
Normal file
85
.claude/skills/worktree/SKILL.md
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
---
|
||||||
|
name: worktree
|
||||||
|
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "3.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# Worktree Setup
|
||||||
|
|
||||||
|
## Create the worktree
|
||||||
|
|
||||||
|
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ROOT=$(git rev-parse --show-toplevel)
|
||||||
|
PARENT=$(dirname "$ROOT")
|
||||||
|
|
||||||
|
# From an existing branch
|
||||||
|
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||||
|
|
||||||
|
# From a new branch off dev
|
||||||
|
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||||
|
```
|
||||||
|
|
||||||
|
## Copy environment files
|
||||||
|
|
||||||
|
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ROOT=$(git rev-parse --show-toplevel)
|
||||||
|
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||||
|
|
||||||
|
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||||
|
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||||
|
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||||
|
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||||
|
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
## Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||||
|
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||||
|
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||||
|
|
||||||
|
## Running the app (optional)
|
||||||
|
|
||||||
|
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||||
|
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||||
|
done
|
||||||
|
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||||
|
```
|
||||||
|
|
||||||
|
## CoPilot testing
|
||||||
|
|
||||||
|
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||||
|
|
||||||
|
## Cleanup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||||
|
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Alternative: Branchlet (optional)
|
||||||
|
|
||||||
|
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||||
|
```
|
||||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,12 +5,14 @@ on:
|
|||||||
branches: [master, dev, ci-test*]
|
branches: [master, dev, ci-test*]
|
||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-backend-ci.yml"
|
- ".github/workflows/platform-backend-ci.yml"
|
||||||
|
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||||
- "autogpt_platform/backend/**"
|
- "autogpt_platform/backend/**"
|
||||||
- "autogpt_platform/autogpt_libs/**"
|
- "autogpt_platform/autogpt_libs/**"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [master, dev, release-*]
|
branches: [master, dev, release-*]
|
||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-backend-ci.yml"
|
- ".github/workflows/platform-backend-ci.yml"
|
||||||
|
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||||
- "autogpt_platform/backend/**"
|
- "autogpt_platform/backend/**"
|
||||||
- "autogpt_platform/autogpt_libs/**"
|
- "autogpt_platform/autogpt_libs/**"
|
||||||
merge_group:
|
merge_group:
|
||||||
|
|||||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,175 +120,6 @@ jobs:
|
|||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
e2e_test:
|
|
||||||
name: end-to-end tests
|
|
||||||
runs-on: big-boi
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
with:
|
|
||||||
submodules: recursive
|
|
||||||
|
|
||||||
- name: Set up Platform - Copy default supabase .env
|
|
||||||
run: |
|
|
||||||
cp ../.env.default ../.env
|
|
||||||
|
|
||||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
|
||||||
run: |
|
|
||||||
cp ../backend/.env.default ../backend/.env
|
|
||||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
|
||||||
env:
|
|
||||||
# Used by E2E test data script to generate embeddings for approved store agents
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
|
|
||||||
- name: Set up Platform - Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
with:
|
|
||||||
driver: docker-container
|
|
||||||
driver-opts: network=host
|
|
||||||
|
|
||||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
|
||||||
uses: crazy-max/ghaction-github-runtime@v4
|
|
||||||
|
|
||||||
- name: Set up Platform - Build Docker images (with cache)
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
pip install pyyaml
|
|
||||||
|
|
||||||
# Resolve extends and generate a flat compose file that bake can understand
|
|
||||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
|
||||||
|
|
||||||
# Add cache configuration to the resolved compose file
|
|
||||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
|
||||||
--source docker-compose.resolved.yml \
|
|
||||||
--cache-from "type=gha" \
|
|
||||||
--cache-to "type=gha,mode=max" \
|
|
||||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
|
||||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
|
||||||
--git-ref "${{ github.ref }}"
|
|
||||||
|
|
||||||
# Build with bake using the resolved compose file (now includes cache config)
|
|
||||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
|
||||||
env:
|
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
|
||||||
|
|
||||||
- name: Set up tests - Cache E2E test data
|
|
||||||
id: e2e-data-cache
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
|
||||||
path: /tmp/e2e_test_data.sql
|
|
||||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
|
||||||
|
|
||||||
- name: Set up Platform - Start Supabase DB + Auth
|
|
||||||
run: |
|
|
||||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
|
||||||
echo "Waiting for database to be ready..."
|
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
|
||||||
echo "Waiting for auth service to be ready..."
|
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
|
||||||
|
|
||||||
- name: Set up Platform - Run migrations
|
|
||||||
run: |
|
|
||||||
echo "Running migrations..."
|
|
||||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
|
||||||
echo "✅ Migrations completed"
|
|
||||||
env:
|
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
|
||||||
|
|
||||||
- name: Set up tests - Load cached E2E test data
|
|
||||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
|
||||||
run: |
|
|
||||||
echo "✅ Found cached E2E test data, restoring..."
|
|
||||||
{
|
|
||||||
echo "SET session_replication_role = 'replica';"
|
|
||||||
cat /tmp/e2e_test_data.sql
|
|
||||||
echo "SET session_replication_role = 'origin';"
|
|
||||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
|
||||||
# Refresh materialized views after restore
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
|
||||||
|
|
||||||
echo "✅ E2E test data restored from cache"
|
|
||||||
|
|
||||||
- name: Set up Platform - Start (all other services)
|
|
||||||
run: |
|
|
||||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
|
||||||
echo "Waiting for rest_server to be ready..."
|
|
||||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
|
||||||
env:
|
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
|
||||||
|
|
||||||
- name: Set up tests - Create E2E test data
|
|
||||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
|
||||||
echo "Creating E2E test data..."
|
|
||||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
|
||||||
echo "❌ E2E test data creation failed!"
|
|
||||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
|
||||||
echo "Dumping database for cache..."
|
|
||||||
{
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
pg_dump -U postgres --data-only --column-inserts \
|
|
||||||
--table='auth.users' postgres
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
pg_dump -U postgres --data-only --column-inserts \
|
|
||||||
--schema=platform \
|
|
||||||
--exclude-table='platform._prisma_migrations' \
|
|
||||||
--exclude-table='platform.apscheduler_jobs' \
|
|
||||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
|
||||||
postgres
|
|
||||||
} > /tmp/e2e_test_data.sql
|
|
||||||
|
|
||||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
|
||||||
|
|
||||||
- name: Set up tests - Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up tests - Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Set up tests - Install dependencies
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Set up tests - Install browser 'chromium'
|
|
||||||
run: pnpm playwright install --with-deps chromium
|
|
||||||
|
|
||||||
- name: Run Playwright tests
|
|
||||||
run: pnpm test:no-build
|
|
||||||
continue-on-error: false
|
|
||||||
|
|
||||||
- name: Upload Playwright report
|
|
||||||
if: always()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: playwright-report
|
|
||||||
path: playwright-report
|
|
||||||
if-no-files-found: ignore
|
|
||||||
retention-days: 3
|
|
||||||
|
|
||||||
- name: Upload Playwright test results
|
|
||||||
if: always()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: playwright-test-results
|
|
||||||
path: test-results
|
|
||||||
if-no-files-found: ignore
|
|
||||||
retention-days: 3
|
|
||||||
|
|
||||||
- name: Print Final Docker Compose logs
|
|
||||||
if: always()
|
|
||||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
|
||||||
|
|
||||||
integration_test:
|
integration_test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
|
|||||||
312
.github/workflows/platform-fullstack-ci.yml
vendored
312
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,14 +1,18 @@
|
|||||||
name: AutoGPT Platform - Frontend CI
|
name: AutoGPT Platform - Full-stack CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [master, dev]
|
branches: [master, dev]
|
||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-fullstack-ci.yml"
|
- ".github/workflows/platform-fullstack-ci.yml"
|
||||||
|
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||||
|
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||||
- "autogpt_platform/**"
|
- "autogpt_platform/**"
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-fullstack-ci.yml"
|
- ".github/workflows/platform-fullstack-ci.yml"
|
||||||
|
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||||
|
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||||
- "autogpt_platform/**"
|
- "autogpt_platform/**"
|
||||||
merge_group:
|
merge_group:
|
||||||
|
|
||||||
@@ -24,42 +28,28 @@ defaults:
|
|||||||
jobs:
|
jobs:
|
||||||
setup:
|
setup:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Generate cache key
|
- name: Set up Node
|
||||||
id: cache-key
|
uses: actions/setup-node@v6
|
||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Cache dependencies
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies to populate cache
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
check-api-types:
|
||||||
runs-on: big-boi
|
name: check API types
|
||||||
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -67,70 +57,256 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
# ------------------------ Backend setup ------------------------
|
||||||
|
|
||||||
|
- name: Set up Backend - Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Set up Backend - Install Poetry
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: |
|
||||||
|
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||||
|
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||||
|
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||||
|
|
||||||
|
- name: Set up Backend - Set up dependency cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pypoetry
|
||||||
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|
||||||
|
- name: Set up Backend - Install dependencies
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: poetry install
|
||||||
|
|
||||||
|
- name: Set up Backend - Generate Prisma client
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
|
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||||
|
|
||||||
|
# ------------------------ Frontend setup ------------------------
|
||||||
|
|
||||||
|
- name: Set up Frontend - Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Set up Frontend - Set up Node
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
cache: "pnpm"
|
||||||
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Set up Frontend - Install dependencies
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Copy default supabase .env
|
|
||||||
run: |
|
|
||||||
cp ../.env.default ../.env
|
|
||||||
|
|
||||||
- name: Copy backend .env
|
|
||||||
run: |
|
|
||||||
cp ../backend/.env.default ../backend/.env
|
|
||||||
|
|
||||||
- name: Run docker compose
|
|
||||||
run: |
|
|
||||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
|
||||||
path: ~/.pnpm-store
|
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Setup .env
|
- name: Set up Frontend - Format OpenAPI schema
|
||||||
run: cp .env.default .env
|
id: format-schema
|
||||||
|
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||||
- name: Wait for services to be ready
|
|
||||||
run: |
|
|
||||||
echo "Waiting for rest_server to be ready..."
|
|
||||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
|
||||||
echo "Waiting for database to be ready..."
|
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
|
||||||
|
|
||||||
- name: Generate API queries
|
|
||||||
run: pnpm generate:api:force
|
|
||||||
|
|
||||||
- name: Check for API schema changes
|
- name: Check for API schema changes
|
||||||
run: |
|
run: |
|
||||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||||
echo ""
|
echo ""
|
||||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
echo "The openapi.json file has been modified after exporting the API schema."
|
||||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||||
echo "The API schema is now out of sync with the Front-end queries."
|
echo "The API schema is now out of sync with the Front-end queries."
|
||||||
echo ""
|
echo ""
|
||||||
echo "To fix this:"
|
echo "To fix this:"
|
||||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
echo "\nIn the backend directory:"
|
||||||
echo "2. Run 'pnpm generate:api' locally"
|
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||||
echo "3. Run 'pnpm types' locally"
|
echo "\nIn the frontend directory:"
|
||||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||||
echo "5. Commit and push your changes"
|
echo "3. Run 'pnpm generate:api'"
|
||||||
|
echo "4. Run 'pnpm types'"
|
||||||
|
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||||
|
echo "6. Commit and push your changes"
|
||||||
echo ""
|
echo ""
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
echo "✅ No API schema changes detected"
|
echo "✅ No API schema changes detected"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Run Typescript checks
|
- name: Set up Frontend - Generate API client
|
||||||
|
id: generate-api-client
|
||||||
|
run: pnpm orval --config ./orval.config.ts
|
||||||
|
# Continue with type generation & check even if there are schema changes
|
||||||
|
if: success() || (steps.format-schema.outcome == 'success')
|
||||||
|
|
||||||
|
- name: Check for TypeScript errors
|
||||||
run: pnpm types
|
run: pnpm types
|
||||||
|
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||||
|
|
||||||
|
e2e_test:
|
||||||
|
name: end-to-end tests
|
||||||
|
runs-on: big-boi
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Set up Platform - Copy default supabase .env
|
||||||
|
run: |
|
||||||
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
|
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||||
|
run: |
|
||||||
|
cp ../backend/.env.default ../backend/.env
|
||||||
|
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||||
|
env:
|
||||||
|
# Used by E2E test data script to generate embeddings for approved store agents
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
|
- name: Set up Platform - Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
with:
|
||||||
|
driver: docker-container
|
||||||
|
driver-opts: network=host
|
||||||
|
|
||||||
|
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||||
|
uses: crazy-max/ghaction-github-runtime@v4
|
||||||
|
|
||||||
|
- name: Set up Platform - Build Docker images (with cache)
|
||||||
|
working-directory: autogpt_platform
|
||||||
|
run: |
|
||||||
|
pip install pyyaml
|
||||||
|
|
||||||
|
# Resolve extends and generate a flat compose file that bake can understand
|
||||||
|
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||||
|
|
||||||
|
# Add cache configuration to the resolved compose file
|
||||||
|
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||||
|
--source docker-compose.resolved.yml \
|
||||||
|
--cache-from "type=gha" \
|
||||||
|
--cache-to "type=gha,mode=max" \
|
||||||
|
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||||
|
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||||
|
--git-ref "${{ github.ref }}"
|
||||||
|
|
||||||
|
# Build with bake using the resolved compose file (now includes cache config)
|
||||||
|
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
|
- name: Set up tests - Cache E2E test data
|
||||||
|
id: e2e-data-cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: /tmp/e2e_test_data.sql
|
||||||
|
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||||
|
|
||||||
|
- name: Set up Platform - Start Supabase DB + Auth
|
||||||
|
run: |
|
||||||
|
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||||
|
echo "Waiting for database to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||||
|
echo "Waiting for auth service to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||||
|
|
||||||
|
- name: Set up Platform - Run migrations
|
||||||
|
run: |
|
||||||
|
echo "Running migrations..."
|
||||||
|
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||||
|
echo "✅ Migrations completed"
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
|
- name: Set up tests - Load cached E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||||
|
run: |
|
||||||
|
echo "✅ Found cached E2E test data, restoring..."
|
||||||
|
{
|
||||||
|
echo "SET session_replication_role = 'replica';"
|
||||||
|
cat /tmp/e2e_test_data.sql
|
||||||
|
echo "SET session_replication_role = 'origin';"
|
||||||
|
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||||
|
# Refresh materialized views after restore
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||||
|
|
||||||
|
echo "✅ E2E test data restored from cache"
|
||||||
|
|
||||||
|
- name: Set up Platform - Start (all other services)
|
||||||
|
run: |
|
||||||
|
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||||
|
echo "Waiting for rest_server to be ready..."
|
||||||
|
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
|
- name: Set up tests - Create E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||||
|
run: |
|
||||||
|
echo "Creating E2E test data..."
|
||||||
|
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||||
|
echo "❌ E2E test data creation failed!"
|
||||||
|
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||||
|
echo "Dumping database for cache..."
|
||||||
|
{
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--table='auth.users' postgres
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--schema=platform \
|
||||||
|
--exclude-table='platform._prisma_migrations' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||||
|
postgres
|
||||||
|
} > /tmp/e2e_test_data.sql
|
||||||
|
|
||||||
|
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||||
|
|
||||||
|
- name: Set up tests - Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Set up tests - Set up Node
|
||||||
|
uses: actions/setup-node@v6
|
||||||
|
with:
|
||||||
|
node-version: "22.18.0"
|
||||||
|
cache: "pnpm"
|
||||||
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
|
|
||||||
|
- name: Set up tests - Install dependencies
|
||||||
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
|
- name: Set up tests - Install browser 'chromium'
|
||||||
|
run: pnpm playwright install --with-deps chromium
|
||||||
|
|
||||||
|
- name: Run Playwright tests
|
||||||
|
run: pnpm test:no-build
|
||||||
|
continue-on-error: false
|
||||||
|
|
||||||
|
- name: Upload Playwright report
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: playwright-report
|
||||||
|
path: playwright-report
|
||||||
|
if-no-files-found: ignore
|
||||||
|
retention-days: 3
|
||||||
|
|
||||||
|
- name: Upload Playwright test results
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: playwright-test-results
|
||||||
|
path: test-results
|
||||||
|
if-no-files-found: ignore
|
||||||
|
retention-days: 3
|
||||||
|
|
||||||
|
- name: Print Final Docker Compose logs
|
||||||
|
if: always()
|
||||||
|
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||||
|
|||||||
@@ -60,9 +60,12 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
|
|
||||||
### Reviewing/Revising Pull Requests
|
### Reviewing/Revising Pull Requests
|
||||||
|
|
||||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
|
||||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
When fetching comments manually:
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||||
|
|
||||||
### Conventional Commits
|
### Conventional Commits
|
||||||
|
|
||||||
|
|||||||
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.auth_activities
|
||||||
|
-- Looker source alias: ds49 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Tracks authentication events (login, logout, SSO, password
|
||||||
|
-- reset, etc.) from Supabase's internal audit log.
|
||||||
|
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.audit_log_entries — Supabase internal auth event log
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||||
|
-- actor_id TEXT User ID who triggered the event
|
||||||
|
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||||
|
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days from current date
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Daily login counts
|
||||||
|
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||||
|
-- FROM analytics.auth_activities
|
||||||
|
-- WHERE action = 'login'
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
--
|
||||||
|
-- -- SSO vs password login breakdown
|
||||||
|
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||||
|
-- WHERE action = 'login' GROUP BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
created_at,
|
||||||
|
payload->>'actor_id' AS actor_id,
|
||||||
|
payload->>'actor_via_sso' AS actor_via_sso,
|
||||||
|
payload->>'action' AS action
|
||||||
|
FROM auth.audit_log_entries
|
||||||
|
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||||
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.graph_execution
|
||||||
|
-- Looker source alias: ds16 | Charts: 21
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per agent graph execution (last 90 days).
|
||||||
|
-- Unpacks the JSONB stats column into individual numeric columns
|
||||||
|
-- and normalises the executionStatus — runs that failed due to
|
||||||
|
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||||
|
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||||
|
-- to allow safe grouping.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||||
|
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Execution UUID
|
||||||
|
-- agentGraphId TEXT Agent graph UUID
|
||||||
|
-- agentGraphVersion INT Graph version number
|
||||||
|
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||||
|
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||||
|
-- updatedAt TIMESTAMPTZ Last status update time
|
||||||
|
-- userId TEXT Owner user UUID
|
||||||
|
-- agentGraphName TEXT Human-readable agent name
|
||||||
|
-- cputime DECIMAL Total CPU seconds consumed
|
||||||
|
-- walltime DECIMAL Total wall-clock seconds
|
||||||
|
-- node_count DECIMAL Number of nodes in the graph
|
||||||
|
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||||
|
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||||
|
-- execution_cost DECIMAL Credit cost of this execution
|
||||||
|
-- correctness_score FLOAT AI correctness score (if available)
|
||||||
|
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||||
|
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Daily execution counts by status
|
||||||
|
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- GROUP BY 1, 2 ORDER BY 1;
|
||||||
|
--
|
||||||
|
-- -- Average cost per execution by agent
|
||||||
|
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- WHERE "executionStatus" = 'COMPLETED'
|
||||||
|
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||||
|
--
|
||||||
|
-- -- Top error messages
|
||||||
|
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- WHERE "executionStatus" = 'FAILED'
|
||||||
|
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
ge."id" AS id,
|
||||||
|
ge."agentGraphId" AS agentGraphId,
|
||||||
|
ge."agentGraphVersion" AS agentGraphVersion,
|
||||||
|
CASE
|
||||||
|
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||||
|
AND (
|
||||||
|
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||||
|
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||||
|
)
|
||||||
|
THEN 'NO_CREDITS'
|
||||||
|
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||||
|
END AS executionStatus,
|
||||||
|
ge."createdAt" AS createdAt,
|
||||||
|
ge."updatedAt" AS updatedAt,
|
||||||
|
ge."userId" AS userId,
|
||||||
|
g."name" AS agentGraphName,
|
||||||
|
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||||
|
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||||
|
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||||
|
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||||
|
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||||
|
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||||
|
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||||
|
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||||
|
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||||
|
'\1\2/...', 'gi'
|
||||||
|
),
|
||||||
|
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||||
|
) AS groupedErrorMessage
|
||||||
|
FROM platform."AgentGraphExecution" ge
|
||||||
|
LEFT JOIN platform."AgentGraph" g
|
||||||
|
ON ge."agentGraphId" = g."id"
|
||||||
|
AND ge."agentGraphVersion" = g."version"
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||||
|
"userId", "agentGraphId",
|
||||||
|
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||||
|
FROM platform."LibraryAgent"
|
||||||
|
WHERE "isDeleted" = FALSE
|
||||||
|
AND "isArchived" = FALSE
|
||||||
|
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||||
|
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||||
|
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.node_block_execution
|
||||||
|
-- Looker source alias: ds14 | Charts: 11
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per node (block) execution (last 90 days).
|
||||||
|
-- Unpacks stats JSONB and joins to identify which block type
|
||||||
|
-- was run. For failed nodes, joins the error output and
|
||||||
|
-- scrubs it for safe grouping.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentNodeExecution — Node execution records
|
||||||
|
-- platform.AgentNode — Node → block mapping
|
||||||
|
-- platform.AgentBlock — Block name/ID
|
||||||
|
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Node execution UUID
|
||||||
|
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||||
|
-- agentNodeId TEXT Node UUID within the graph
|
||||||
|
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||||
|
-- addedTime TIMESTAMPTZ When the node was queued
|
||||||
|
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||||
|
-- startedTime TIMESTAMPTZ When execution started
|
||||||
|
-- endedTime TIMESTAMPTZ When execution finished
|
||||||
|
-- inputSize BIGINT Input payload size in bytes
|
||||||
|
-- outputSize BIGINT Output payload size in bytes
|
||||||
|
-- walltime NUMERIC Wall-clock seconds for this node
|
||||||
|
-- cputime NUMERIC CPU seconds for this node
|
||||||
|
-- llmRetryCount INT Number of LLM retries
|
||||||
|
-- llmCallCount INT Number of LLM API calls made
|
||||||
|
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||||
|
-- outputTokenCount BIGINT LLM output tokens produced
|
||||||
|
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||||
|
-- blockId TEXT Block UUID
|
||||||
|
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||||
|
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Most-used blocks by execution count
|
||||||
|
-- SELECT "blockName", COUNT(*) AS executions,
|
||||||
|
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||||
|
--
|
||||||
|
-- -- Average LLM token usage per block
|
||||||
|
-- SELECT "blockName",
|
||||||
|
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||||
|
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- WHERE "llmCallCount" > 0
|
||||||
|
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||||
|
--
|
||||||
|
-- -- Top failure reasons
|
||||||
|
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- WHERE "executionStatus" = 'FAILED'
|
||||||
|
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
ne."id" AS id,
|
||||||
|
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||||
|
ne."agentNodeId" AS agentNodeId,
|
||||||
|
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||||
|
ne."addedTime" AS addedTime,
|
||||||
|
ne."queuedTime" AS queuedTime,
|
||||||
|
ne."startedTime" AS startedTime,
|
||||||
|
ne."endedTime" AS endedTime,
|
||||||
|
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||||
|
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||||
|
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||||
|
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||||
|
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||||
|
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||||
|
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||||
|
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||||
|
b."name" AS blockName,
|
||||||
|
b."id" AS blockId,
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
TRIM(BOTH '"' FROM eio."data"::text),
|
||||||
|
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||||
|
'\1\2/...', 'gi'
|
||||||
|
),
|
||||||
|
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||||
|
) AS groupedErrorMessage,
|
||||||
|
eio."data" AS errorMessage
|
||||||
|
FROM platform."AgentNodeExecution" ne
|
||||||
|
LEFT JOIN platform."AgentNode" nd
|
||||||
|
ON ne."agentNodeId" = nd."id"
|
||||||
|
LEFT JOIN platform."AgentBlock" b
|
||||||
|
ON nd."agentBlockId" = b."id"
|
||||||
|
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||||
|
ON eio."referencedByOutputExecId" = ne."id"
|
||||||
|
AND eio."name" = 'error'
|
||||||
|
AND ne."executionStatus" = 'FAILED'
|
||||||
|
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_agent
|
||||||
|
-- Looker source alias: ds35 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention broken down per individual agent.
|
||||||
|
-- Cohort = week of a user's first use of THAT specific agent.
|
||||||
|
-- Tells you which agents keep users coming back vs. one-shot
|
||||||
|
-- use. Only includes cohorts from the last 180 days.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||||
|
-- platform.AgentGraph — Agent names
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- agent_id TEXT Agent graph UUID
|
||||||
|
-- agent_label TEXT 'AgentName [first8chars]'
|
||||||
|
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||||
|
-- cohort_week_start DATE Week users first ran this agent
|
||||||
|
-- cohort_label TEXT ISO week label
|
||||||
|
-- cohort_label_n TEXT ISO week label with cohort size
|
||||||
|
-- user_lifetime_week INT Weeks since first use of this agent
|
||||||
|
-- cohort_users BIGINT Users in this cohort for this agent
|
||||||
|
-- active_users BIGINT Users who ran the agent again in week k
|
||||||
|
-- retention_rate FLOAT active_users / cohort_users
|
||||||
|
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||||
|
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Best-retained agents at week 2
|
||||||
|
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||||
|
-- FROM analytics.retention_agent
|
||||||
|
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||||
|
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||||
|
--
|
||||||
|
-- -- Agents with most unique users
|
||||||
|
-- SELECT DISTINCT agent_label, agent_total_users
|
||||||
|
-- FROM analytics.retention_agent
|
||||||
|
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||||
|
e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||||
|
FROM platform."AgentGraphExecution" e
|
||||||
|
),
|
||||||
|
first_use AS (
|
||||||
|
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1,2
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||||
|
),
|
||||||
|
active_counts AS (
|
||||||
|
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||||
|
),
|
||||||
|
cohort_sizes AS (
|
||||||
|
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
),
|
||||||
|
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||||
|
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||||
|
SELECT
|
||||||
|
g.agent_id,
|
||||||
|
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||||
|
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(ac.active_users,0) AS active_users,
|
||||||
|
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||||
|
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||||
|
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||||
|
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_execution_daily
|
||||||
|
-- Looker source alias: ds111 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Daily cohort retention based on agent executions.
|
||||||
|
-- Cohort anchor = day of user's FIRST ever execution.
|
||||||
|
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||||
|
-- Great for early engagement analysis (did users run another
|
||||||
|
-- agent the next day?).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same pattern as retention_login_daily.
|
||||||
|
-- cohort_day_start = day of first execution (not first login)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Day-3 execution retention
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||||
|
-- FROM analytics.retention_execution_daily
|
||||||
|
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||||
|
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||||
|
),
|
||||||
|
first_exec AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||||
|
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||||
|
user_day_age AS (
|
||||||
|
SELECT ad.user_id, fe.cohort_day_start,
|
||||||
|
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||||
|
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||||
|
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_day_start,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_day, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||||
|
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_execution_weekly
|
||||||
|
-- Looker source alias: ds92 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention based on agent executions.
|
||||||
|
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||||
|
-- (not first login). Only includes cohorts from the last 180 days.
|
||||||
|
-- Useful when you care about product engagement, not just visits.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same pattern as retention_login_weekly.
|
||||||
|
-- cohort_week_start = week of first execution (not first login)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Week-2 execution retention
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded
|
||||||
|
-- FROM analytics.retention_execution_weekly
|
||||||
|
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||||
|
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||||
|
),
|
||||||
|
first_exec AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fe.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||||
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_daily
|
||||||
|
-- Looker source alias: ds112 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Daily cohort retention based on login sessions.
|
||||||
|
-- Same logic as retention_login_weekly but at day granularity,
|
||||||
|
-- showing up to day 30 for cohorts from the last 90 days.
|
||||||
|
-- Useful for analysing early activation (days 1-7) in detail.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||||
|
-- cohort_day_start DATE First day the cohort logged in
|
||||||
|
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||||
|
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||||
|
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||||
|
-- cohort_users BIGINT Total users in cohort
|
||||||
|
-- active_users_bounded BIGINT Users active on exactly day k
|
||||||
|
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||||
|
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||||
|
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||||
|
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Day-1 retention rate (came back next day)
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||||
|
-- FROM analytics.retention_login_daily
|
||||||
|
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||||
|
--
|
||||||
|
-- -- Average retention curve across all cohorts
|
||||||
|
-- SELECT user_lifetime_day,
|
||||||
|
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||||
|
-- FROM analytics.retention_login_daily
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||||
|
user_day_age AS (
|
||||||
|
SELECT ad.user_id, fl.cohort_day_start,
|
||||||
|
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||||
|
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||||
|
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_day_start,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_day, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||||
|
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_onboarded_weekly
|
||||||
|
-- Looker source alias: ds101 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention from login sessions, restricted to
|
||||||
|
-- users who "onboarded" — defined as running at least one
|
||||||
|
-- agent within 365 days of their first login.
|
||||||
|
-- Filters out users who signed up but never activated,
|
||||||
|
-- giving a cleaner view of engaged-user retention.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||||
|
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||||
|
-- Only difference: cohort is filtered to onboarded users only.
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Compare week-4 retention: all users vs onboarded only
|
||||||
|
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||||
|
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||||
|
-- UNION ALL
|
||||||
|
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||||
|
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login_all AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
),
|
||||||
|
onboarders AS (
|
||||||
|
SELECT fl.user_id FROM first_login_all fl
|
||||||
|
WHERE EXISTS (
|
||||||
|
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||||
|
WHERE e."userId"::text = fl.user_id
|
||||||
|
AND e."createdAt" >= fl.first_login_time
|
||||||
|
AND e."createdAt" < fl.first_login_time
|
||||||
|
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fl.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||||
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_weekly
|
||||||
|
-- Looker source alias: ds83 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention based on login sessions.
|
||||||
|
-- Users are grouped by the ISO week of their first ever login.
|
||||||
|
-- For each cohort × lifetime-week combination, outputs both:
|
||||||
|
-- - bounded rate: % active in exactly that week
|
||||||
|
-- - unbounded rate: % who were ever active on or after that week
|
||||||
|
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
--
|
||||||
|
-- HOW TO READ THE OUTPUT
|
||||||
|
-- cohort_week_start The Monday of the week users first logged in
|
||||||
|
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||||
|
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||||
|
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- cohort_week_start DATE First day of the cohort's signup week
|
||||||
|
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||||
|
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||||
|
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||||
|
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||||
|
-- active_users_bounded BIGINT Users active in exactly week k
|
||||||
|
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||||
|
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||||
|
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||||
|
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Week-1 retention rate per cohort
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||||
|
-- FROM analytics.retention_login_weekly
|
||||||
|
-- WHERE user_lifetime_week = 1
|
||||||
|
-- ORDER BY cohort_week_start;
|
||||||
|
--
|
||||||
|
-- -- Overall average retention curve (all cohorts combined)
|
||||||
|
-- SELECT user_lifetime_week,
|
||||||
|
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||||
|
-- FROM analytics.retention_login_weekly
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fl.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||||
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_block_spending
|
||||||
|
-- Looker source alias: ds6 | Charts: 5
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per credit transaction (last 90 days).
|
||||||
|
-- Shows how users spend credits broken down by block type,
|
||||||
|
-- LLM provider and model. Joins node execution stats for
|
||||||
|
-- token-level detail.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.CreditTransaction — Credit debit/credit records
|
||||||
|
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- transactionKey TEXT Unique transaction identifier
|
||||||
|
-- userId TEXT User who was charged
|
||||||
|
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||||
|
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||||
|
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||||
|
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||||
|
-- blockId TEXT Block UUID that triggered the spend
|
||||||
|
-- blockName TEXT Human-readable block name
|
||||||
|
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||||
|
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||||
|
-- node_exec_id TEXT Linked node execution UUID
|
||||||
|
-- llm_call_count INT LLM API calls made in that execution
|
||||||
|
-- llm_retry_count INT LLM retries in that execution
|
||||||
|
-- llm_input_token_count INT Input tokens consumed
|
||||||
|
-- llm_output_token_count INT Output tokens produced
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Total spend per user (last 90 days)
|
||||||
|
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||||
|
-- FROM analytics.user_block_spending
|
||||||
|
-- WHERE "transactionType" = 'USAGE'
|
||||||
|
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||||
|
--
|
||||||
|
-- -- Spend by LLM provider + model
|
||||||
|
-- SELECT "llm_provider", "llm_model",
|
||||||
|
-- SUM("negativeAmount") AS total_cost,
|
||||||
|
-- SUM("llm_input_token_count") AS input_tokens,
|
||||||
|
-- SUM("llm_output_token_count") AS output_tokens
|
||||||
|
-- FROM analytics.user_block_spending
|
||||||
|
-- WHERE "llm_provider" IS NOT NULL
|
||||||
|
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
c."transactionKey" AS transactionKey,
|
||||||
|
c."userId" AS userId,
|
||||||
|
c."amount" AS amount,
|
||||||
|
c."amount" * -1 AS negativeAmount,
|
||||||
|
c."type" AS transactionType,
|
||||||
|
c."createdAt" AS transactionTime,
|
||||||
|
c.metadata->>'block_id' AS blockId,
|
||||||
|
c.metadata->>'block' AS blockName,
|
||||||
|
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||||
|
c.metadata->'input'->>'model' AS llm_model,
|
||||||
|
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||||
|
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||||
|
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||||
|
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||||
|
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||||
|
FROM platform."CreditTransaction" c
|
||||||
|
LEFT JOIN platform."AgentNodeExecution" ne
|
||||||
|
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||||
|
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding
|
||||||
|
-- Looker source alias: ds68 | Charts: 3
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per user onboarding record. Contains the user's
|
||||||
|
-- stated usage reason, selected integrations, completed
|
||||||
|
-- onboarding steps and optional first agent selection.
|
||||||
|
-- Full history (no date filter) since onboarding happens
|
||||||
|
-- once per user.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — Onboarding state per user
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Onboarding record UUID
|
||||||
|
-- createdAt TIMESTAMPTZ When onboarding started
|
||||||
|
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||||
|
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||||
|
-- integrations TEXT[] Array of integration names the user selected
|
||||||
|
-- userId TEXT User UUID
|
||||||
|
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||||
|
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Usage reason breakdown
|
||||||
|
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||||
|
--
|
||||||
|
-- -- Completion rate per step
|
||||||
|
-- SELECT step, COUNT(*) AS users_completed
|
||||||
|
-- FROM analytics.user_onboarding
|
||||||
|
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||||
|
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
"createdAt",
|
||||||
|
"updatedAt",
|
||||||
|
"usageReason",
|
||||||
|
integrations,
|
||||||
|
"userId",
|
||||||
|
"completedSteps",
|
||||||
|
"selectedStoreListingVersionId"
|
||||||
|
FROM platform."UserOnboarding"
|
||||||
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding_funnel
|
||||||
|
-- Looker source alias: ds74 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Pre-aggregated onboarding funnel showing how many users
|
||||||
|
-- completed each step and the drop-off percentage from the
|
||||||
|
-- previous step. One row per onboarding step (all 22 steps
|
||||||
|
-- always present, even with 0 completions — prevents sparse
|
||||||
|
-- gaps from making LAG compare the wrong predecessors).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||||
|
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||||
|
-- users_completed BIGINT Distinct users who completed this step
|
||||||
|
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||||
|
--
|
||||||
|
-- STEP ORDER
|
||||||
|
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||||
|
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||||
|
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||||
|
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||||
|
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||||
|
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||||
|
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||||
|
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Users who started onboarding in the last 90 days
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Full funnel
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||||
|
--
|
||||||
|
-- -- Biggest drop-off point
|
||||||
|
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||||
|
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH all_steps AS (
|
||||||
|
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||||
|
-- are always present, keeping LAG comparisons correct.
|
||||||
|
SELECT step_name, step_order
|
||||||
|
FROM (VALUES
|
||||||
|
('WELCOME', 1),
|
||||||
|
('USAGE_REASON', 2),
|
||||||
|
('INTEGRATIONS', 3),
|
||||||
|
('AGENT_CHOICE', 4),
|
||||||
|
('AGENT_NEW_RUN', 5),
|
||||||
|
('AGENT_INPUT', 6),
|
||||||
|
('CONGRATS', 7),
|
||||||
|
('GET_RESULTS', 8),
|
||||||
|
('MARKETPLACE_VISIT', 9),
|
||||||
|
('MARKETPLACE_ADD_AGENT', 10),
|
||||||
|
('MARKETPLACE_RUN_AGENT', 11),
|
||||||
|
('BUILDER_OPEN', 12),
|
||||||
|
('BUILDER_SAVE_AGENT', 13),
|
||||||
|
('BUILDER_RUN_AGENT', 14),
|
||||||
|
('VISIT_COPILOT', 15),
|
||||||
|
('RE_RUN_AGENT', 16),
|
||||||
|
('SCHEDULE_AGENT', 17),
|
||||||
|
('RUN_AGENTS', 18),
|
||||||
|
('RUN_3_DAYS', 19),
|
||||||
|
('TRIGGER_WEBHOOK', 20),
|
||||||
|
('RUN_14_DAYS', 21),
|
||||||
|
('RUN_AGENTS_100', 22)
|
||||||
|
) AS t(step_name, step_order)
|
||||||
|
),
|
||||||
|
raw AS (
|
||||||
|
SELECT
|
||||||
|
u."userId",
|
||||||
|
step_txt::text AS step
|
||||||
|
FROM platform."UserOnboarding" u
|
||||||
|
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||||
|
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||||
|
),
|
||||||
|
step_counts AS (
|
||||||
|
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||||
|
FROM raw GROUP BY step
|
||||||
|
),
|
||||||
|
funnel AS (
|
||||||
|
SELECT
|
||||||
|
a.step_name AS step,
|
||||||
|
a.step_order,
|
||||||
|
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||||
|
ROUND(
|
||||||
|
100.0 * COALESCE(sc.users_completed, 0)
|
||||||
|
/ NULLIF(
|
||||||
|
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||||
|
0
|
||||||
|
),
|
||||||
|
2
|
||||||
|
) AS pct_from_prev
|
||||||
|
FROM all_steps a
|
||||||
|
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||||
|
)
|
||||||
|
SELECT * FROM funnel ORDER BY step_order
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding_integration
|
||||||
|
-- Looker source alias: ds75 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Pre-aggregated count of users who selected each integration
|
||||||
|
-- during onboarding. One row per integration type, sorted
|
||||||
|
-- by popularity.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — integrations array column
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||||
|
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Users who started onboarding in the last 90 days
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Full integration popularity ranking
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||||
|
--
|
||||||
|
-- -- Top 5 integrations
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH exploded AS (
|
||||||
|
SELECT
|
||||||
|
u."userId" AS user_id,
|
||||||
|
UNNEST(u."integrations") AS integration
|
||||||
|
FROM platform."UserOnboarding" u
|
||||||
|
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
integration,
|
||||||
|
COUNT(DISTINCT user_id) AS users_with_integration
|
||||||
|
FROM exploded
|
||||||
|
WHERE integration IS NOT NULL AND integration <> ''
|
||||||
|
GROUP BY integration
|
||||||
|
ORDER BY users_with_integration DESC
|
||||||
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.users_activities
|
||||||
|
-- Looker source alias: ds56 | Charts: 5
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per user with lifetime activity summary.
|
||||||
|
-- Joins login sessions with agent graphs, executions and
|
||||||
|
-- node-level runs to give a full picture of how engaged
|
||||||
|
-- each user is. Includes a convenience flag for 7-day
|
||||||
|
-- activation (did the user return at least 7 days after
|
||||||
|
-- their first login?).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login/session records
|
||||||
|
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||||
|
-- platform.AgentGraphExecution — Agent run history
|
||||||
|
-- platform.AgentNodeExecution — Individual block execution history
|
||||||
|
--
|
||||||
|
-- PERFORMANCE NOTE
|
||||||
|
-- Each CTE aggregates its own table independently by userId.
|
||||||
|
-- This avoids the fan-out that occurs when driving every join
|
||||||
|
-- from user_logins across the two largest tables
|
||||||
|
-- (AgentGraphExecution and AgentNodeExecution).
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- user_id TEXT Supabase user UUID
|
||||||
|
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||||
|
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||||
|
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||||
|
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||||
|
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||||
|
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||||
|
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||||
|
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||||
|
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||||
|
-- node_execution_count BIGINT Total node executions across all runs
|
||||||
|
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||||
|
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||||
|
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||||
|
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||||
|
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||||
|
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||||
|
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||||
|
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Users who ran at least one agent and returned after 7 days
|
||||||
|
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||||
|
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||||
|
--
|
||||||
|
-- -- Top 10 most active users by agent runs
|
||||||
|
-- SELECT user_id, agent_runs, node_execution_count
|
||||||
|
-- FROM analytics.users_activities
|
||||||
|
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||||
|
--
|
||||||
|
-- -- 7-day activation rate
|
||||||
|
-- SELECT
|
||||||
|
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||||
|
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||||
|
-- AS activation_rate
|
||||||
|
-- FROM analytics.users_activities;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH user_logins AS (
|
||||||
|
SELECT
|
||||||
|
user_id::text AS user_id,
|
||||||
|
MIN(created_at) AS first_login_time,
|
||||||
|
MAX(created_at) AS last_login_time,
|
||||||
|
GREATEST(
|
||||||
|
MAX(refreshed_at)::timestamptz,
|
||||||
|
MAX(created_at)::timestamptz
|
||||||
|
) AS last_visit_time
|
||||||
|
FROM auth.sessions
|
||||||
|
GROUP BY user_id
|
||||||
|
),
|
||||||
|
user_agents AS (
|
||||||
|
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||||
|
SELECT
|
||||||
|
"userId"::text AS user_id,
|
||||||
|
MAX("updatedAt") AS last_agent_save_time,
|
||||||
|
COUNT(DISTINCT "id") AS agent_count
|
||||||
|
FROM platform."AgentGraph"
|
||||||
|
WHERE "isActive"
|
||||||
|
GROUP BY "userId"
|
||||||
|
),
|
||||||
|
user_graph_runs AS (
|
||||||
|
-- Aggregate AgentGraphExecution directly by userId
|
||||||
|
SELECT
|
||||||
|
"userId"::text AS user_id,
|
||||||
|
MIN("createdAt") AS first_agent_run_time,
|
||||||
|
MAX("createdAt") AS last_agent_run_time,
|
||||||
|
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||||
|
COUNT("id") AS agent_runs
|
||||||
|
FROM platform."AgentGraphExecution"
|
||||||
|
GROUP BY "userId"
|
||||||
|
),
|
||||||
|
user_node_runs AS (
|
||||||
|
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||||
|
-- single join to AgentGraphExecution instead of fanning out from
|
||||||
|
-- user_logins through both large tables.
|
||||||
|
SELECT
|
||||||
|
g."userId"::text AS user_id,
|
||||||
|
COUNT(*) AS node_execution_count,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||||
|
FROM platform."AgentNodeExecution" n
|
||||||
|
JOIN platform."AgentGraphExecution" g
|
||||||
|
ON g."id" = n."agentGraphExecutionId"
|
||||||
|
GROUP BY g."userId"
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
ul.user_id,
|
||||||
|
ul.first_login_time,
|
||||||
|
ul.last_login_time,
|
||||||
|
ul.last_visit_time,
|
||||||
|
ua.last_agent_save_time,
|
||||||
|
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||||
|
gr.first_agent_run_time,
|
||||||
|
gr.last_agent_run_time,
|
||||||
|
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||||
|
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||||
|
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||||
|
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||||
|
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||||
|
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||||
|
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||||
|
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||||
|
CASE
|
||||||
|
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||||
|
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||||
|
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||||
|
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||||
|
ELSE NULL
|
||||||
|
END AS is_active_after_7d,
|
||||||
|
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||||
|
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||||
|
FROM user_logins ul
|
||||||
|
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||||
|
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||||
|
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||||
@@ -37,6 +37,10 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
|||||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||||
|
|
||||||
|
## ===== SIGNUP / INVITE GATE ===== ##
|
||||||
|
# Set to true to require an invite before users can sign up
|
||||||
|
ENABLE_INVITE_GATE=false
|
||||||
|
|
||||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||||
# Platform URLs (set these for webhooks and OAuth to work)
|
# Platform URLs (set these for webhooks and OAuth to work)
|
||||||
PLATFORM_BASE_URL=http://localhost:8000
|
PLATFORM_BASE_URL=http://localhost:8000
|
||||||
|
|||||||
@@ -58,10 +58,31 @@ poetry run pytest path/to/test.py --snapshot-update
|
|||||||
- **Authentication**: JWT-based with Supabase integration
|
- **Authentication**: JWT-based with Supabase integration
|
||||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||||
|
|
||||||
|
## Code Style
|
||||||
|
|
||||||
|
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||||
|
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||||
|
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||||
|
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||||
|
- **List comprehensions** over manual loop-and-append
|
||||||
|
- **Early return** — guard clauses first, avoid deep nesting
|
||||||
|
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||||
|
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||||
|
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||||
|
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||||
|
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||||
|
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||||
|
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||||
|
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||||
|
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||||
|
|
||||||
## Testing Approach
|
## Testing Approach
|
||||||
|
|
||||||
- Uses pytest with snapshot testing for API responses
|
- Uses pytest with snapshot testing for API responses
|
||||||
- Test files are colocated with source files (`*_test.py`)
|
- Test files are colocated with source files (`*_test.py`)
|
||||||
|
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||||
|
- After refactoring, update mock targets to match new module paths
|
||||||
|
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||||
|
|
||||||
## Database Schema
|
## Database Schema
|
||||||
|
|
||||||
@@ -157,6 +178,16 @@ yield "image_url", result_url
|
|||||||
3. Write tests alongside the route file
|
3. Write tests alongside the route file
|
||||||
4. Run `poetry run test` to verify
|
4. Run `poetry run test` to verify
|
||||||
|
|
||||||
|
## Workspace & Media Files
|
||||||
|
|
||||||
|
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
|
||||||
|
- Working on CoPilot file upload/download features
|
||||||
|
- Building blocks that handle `MediaFileType` inputs/outputs
|
||||||
|
- Modifying `WorkspaceManager` or `store_media_file()`
|
||||||
|
- Debugging file persistence or virus scanning issues
|
||||||
|
|
||||||
|
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
|
||||||
|
|
||||||
## Security Implementation
|
## Security Implementation
|
||||||
|
|
||||||
### Cache Protection Middleware
|
### Cache Protection Middleware
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
from pydantic import BaseModel
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
|
import prisma.enums
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
from backend.data.model import UserTransaction
|
from backend.data.model import UserTransaction
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||||
|
|
||||||
|
|
||||||
class UserHistoryResponse(BaseModel):
|
class UserHistoryResponse(BaseModel):
|
||||||
"""Response model for listings with version history"""
|
"""Response model for listings with version history"""
|
||||||
@@ -14,3 +23,70 @@ class UserHistoryResponse(BaseModel):
|
|||||||
class AddUserCreditsResponse(BaseModel):
|
class AddUserCreditsResponse(BaseModel):
|
||||||
new_balance: int
|
new_balance: int
|
||||||
transaction_key: str
|
transaction_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class CreateInvitedUserRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvitedUserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
status: prisma.enums.InvitedUserStatus
|
||||||
|
auth_user_id: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
tally_understanding: Optional[dict[str, Any]] = None
|
||||||
|
tally_status: prisma.enums.TallyComputationStatus
|
||||||
|
tally_computed_at: Optional[datetime] = None
|
||||||
|
tally_error: Optional[str] = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||||
|
return cls.model_validate(record.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
class InvitedUsersResponse(BaseModel):
|
||||||
|
invited_users: list[InvitedUserResponse]
|
||||||
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class BulkInvitedUserRowResponse(BaseModel):
|
||||||
|
row_number: int
|
||||||
|
email: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||||
|
message: str
|
||||||
|
invited_user: Optional[InvitedUserResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BulkInvitedUsersResponse(BaseModel):
|
||||||
|
created_count: int
|
||||||
|
skipped_count: int
|
||||||
|
error_count: int
|
||||||
|
results: list[BulkInvitedUserRowResponse]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||||
|
return cls(
|
||||||
|
created_count=result.created_count,
|
||||||
|
skipped_count=result.skipped_count,
|
||||||
|
error_count=result.error_count,
|
||||||
|
results=[
|
||||||
|
BulkInvitedUserRowResponse(
|
||||||
|
row_number=row.row_number,
|
||||||
|
email=row.email,
|
||||||
|
name=row.name,
|
||||||
|
status=row.status,
|
||||||
|
message=row.message,
|
||||||
|
invited_user=(
|
||||||
|
InvitedUserResponse.from_record(row.invited_user)
|
||||||
|
if row.invited_user is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for row in result.results
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||||
|
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||||
|
|
||||||
|
from backend.data.invited_user import (
|
||||||
|
bulk_create_invited_users_from_file,
|
||||||
|
create_invited_user,
|
||||||
|
list_invited_users,
|
||||||
|
retry_invited_user_tally,
|
||||||
|
revoke_invited_user,
|
||||||
|
)
|
||||||
|
from backend.data.tally import mask_email
|
||||||
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from .model import (
|
||||||
|
BulkInvitedUsersResponse,
|
||||||
|
CreateInvitedUserRequest,
|
||||||
|
InvitedUserResponse,
|
||||||
|
InvitedUsersResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/admin",
|
||||||
|
tags=["users", "admin"],
|
||||||
|
dependencies=[Security(requires_admin_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/invited-users",
|
||||||
|
response_model=InvitedUsersResponse,
|
||||||
|
summary="List Invited Users",
|
||||||
|
)
|
||||||
|
async def get_invited_users(
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(50, ge=1, le=200),
|
||||||
|
) -> InvitedUsersResponse:
|
||||||
|
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||||
|
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||||
|
return InvitedUsersResponse(
|
||||||
|
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||||
|
pagination=Pagination(
|
||||||
|
total_items=total,
|
||||||
|
total_pages=max(1, math.ceil(total / page_size)),
|
||||||
|
current_page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Create Invited User",
|
||||||
|
)
|
||||||
|
async def create_invited_user_route(
|
||||||
|
request: CreateInvitedUserRequest,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s creating invited user for %s",
|
||||||
|
admin_user_id,
|
||||||
|
mask_email(request.email),
|
||||||
|
)
|
||||||
|
invited_user = await create_invited_user(request.email, request.name)
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s created invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user.id,
|
||||||
|
)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/bulk",
|
||||||
|
response_model=BulkInvitedUsersResponse,
|
||||||
|
summary="Bulk Create Invited Users",
|
||||||
|
operation_id="postV2BulkCreateInvitedUsers",
|
||||||
|
)
|
||||||
|
async def bulk_create_invited_users_route(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> BulkInvitedUsersResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s bulk invited users from %s",
|
||||||
|
admin_user_id,
|
||||||
|
file.filename or "<unnamed>",
|
||||||
|
)
|
||||||
|
content = await file.read()
|
||||||
|
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||||
|
return BulkInvitedUsersResponse.from_result(result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/{invited_user_id}/revoke",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Revoke Invited User",
|
||||||
|
)
|
||||||
|
async def revoke_invited_user_route(
|
||||||
|
invited_user_id: str,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||||
|
)
|
||||||
|
invited_user = await revoke_invited_user(invited_user_id)
|
||||||
|
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/{invited_user_id}/retry-tally",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Retry Invited User Tally",
|
||||||
|
)
|
||||||
|
async def retry_invited_user_tally_route(
|
||||||
|
invited_user_id: str,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s retrying Tally seed for invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
)
|
||||||
|
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s retried Tally seed for invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import prisma.enums
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
|
||||||
|
from backend.data.invited_user import (
|
||||||
|
BulkInvitedUserRowResult,
|
||||||
|
BulkInvitedUsersResult,
|
||||||
|
InvitedUserRecord,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .user_admin_routes import router as user_admin_router
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(user_admin_router)
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_admin_auth(mock_jwt_admin):
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_invited_user() -> InvitedUserRecord:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
return InvitedUserRecord(
|
||||||
|
id="invite-1",
|
||||||
|
email="invited@example.com",
|
||||||
|
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||||
|
auth_user_id=None,
|
||||||
|
name="Invited User",
|
||||||
|
tally_understanding=None,
|
||||||
|
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||||
|
tally_computed_at=None,
|
||||||
|
tally_error=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||||
|
return BulkInvitedUsersResult(
|
||||||
|
created_count=1,
|
||||||
|
skipped_count=1,
|
||||||
|
error_count=0,
|
||||||
|
results=[
|
||||||
|
BulkInvitedUserRowResult(
|
||||||
|
row_number=1,
|
||||||
|
email="invited@example.com",
|
||||||
|
name=None,
|
||||||
|
status="CREATED",
|
||||||
|
message="Invite created",
|
||||||
|
invited_user=_sample_invited_user(),
|
||||||
|
),
|
||||||
|
BulkInvitedUserRowResult(
|
||||||
|
row_number=2,
|
||||||
|
email="duplicate@example.com",
|
||||||
|
name=None,
|
||||||
|
status="SKIPPED",
|
||||||
|
message="An invited user with this email already exists",
|
||||||
|
invited_user=None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_invited_users(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||||
|
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/admin/invited-users")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["invited_users"]) == 1
|
||||||
|
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||||
|
assert data["invited_users"][0]["status"] == "INVITED"
|
||||||
|
assert data["pagination"]["total_items"] == 1
|
||||||
|
assert data["pagination"]["current_page"] == 1
|
||||||
|
assert data["pagination"]["page_size"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_invited_user(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||||
|
AsyncMock(return_value=_sample_invited_user()),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/invited-users",
|
||||||
|
json={"email": "invited@example.com", "name": "Invited User"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == "invited@example.com"
|
||||||
|
assert data["name"] == "Invited User"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_create_invited_users(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||||
|
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/invited-users/bulk",
|
||||||
|
files={
|
||||||
|
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["created_count"] == 1
|
||||||
|
assert data["skipped_count"] == 1
|
||||||
|
assert data["results"][0]["status"] == "CREATED"
|
||||||
|
assert data["results"][1]["status"] == "SKIPPED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_invited_user(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
revoked = _sample_invited_user().model_copy(
|
||||||
|
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||||
|
AsyncMock(return_value=revoked),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "REVOKED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_invited_user_tally(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
retried = _sample_invited_user().model_copy(
|
||||||
|
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||||
|
AsyncMock(return_value=retried),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["tally_status"] == "RUNNING"
|
||||||
@@ -28,6 +28,7 @@ from backend.copilot.model import (
|
|||||||
update_session_title,
|
update_session_title,
|
||||||
)
|
)
|
||||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||||
|
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||||
from backend.copilot.tools.models import (
|
from backend.copilot.tools.models import (
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
AgentOutputResponse,
|
AgentOutputResponse,
|
||||||
@@ -52,6 +53,8 @@ from backend.copilot.tools.models import (
|
|||||||
UnderstandingUpdatedResponse,
|
UnderstandingUpdatedResponse,
|
||||||
)
|
)
|
||||||
from backend.copilot.tracking import track_user_message
|
from backend.copilot.tracking import track_user_message
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
from backend.data.understanding import get_business_understanding
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
@@ -126,6 +129,7 @@ class SessionSummaryResponse(BaseModel):
|
|||||||
created_at: str
|
created_at: str
|
||||||
updated_at: str
|
updated_at: str
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
|
is_processing: bool
|
||||||
|
|
||||||
|
|
||||||
class ListSessionsResponse(BaseModel):
|
class ListSessionsResponse(BaseModel):
|
||||||
@@ -184,6 +188,28 @@ async def list_sessions(
|
|||||||
"""
|
"""
|
||||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
|
# Batch-check Redis for active stream status on each session
|
||||||
|
processing_set: set[str] = set()
|
||||||
|
if sessions:
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
pipe = redis.pipeline(transaction=False)
|
||||||
|
for session in sessions:
|
||||||
|
pipe.hget(
|
||||||
|
f"{config.session_meta_prefix}{session.session_id}",
|
||||||
|
"status",
|
||||||
|
)
|
||||||
|
statuses = await pipe.execute()
|
||||||
|
processing_set = {
|
||||||
|
session.session_id
|
||||||
|
for session, st in zip(sessions, statuses)
|
||||||
|
if st == "running"
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||||
|
)
|
||||||
|
|
||||||
return ListSessionsResponse(
|
return ListSessionsResponse(
|
||||||
sessions=[
|
sessions=[
|
||||||
SessionSummaryResponse(
|
SessionSummaryResponse(
|
||||||
@@ -191,6 +217,7 @@ async def list_sessions(
|
|||||||
created_at=session.started_at.isoformat(),
|
created_at=session.started_at.isoformat(),
|
||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
title=session.title,
|
title=session.title,
|
||||||
|
is_processing=session.session_id in processing_set,
|
||||||
)
|
)
|
||||||
for session in sessions
|
for session in sessions
|
||||||
],
|
],
|
||||||
@@ -265,12 +292,12 @@ async def delete_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Best-effort cleanup of the E2B sandbox (if any).
|
# Best-effort cleanup of the E2B sandbox (if any).
|
||||||
config = ChatConfig()
|
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
e2b_cfg = ChatConfig()
|
||||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
if e2b_cfg.e2b_active:
|
||||||
|
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||||
try:
|
try:
|
||||||
await kill_sandbox(session_id, config.e2b_api_key)
|
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||||
@@ -827,6 +854,36 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Suggested Prompts ==========
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestedPromptsResponse(BaseModel):
|
||||||
|
"""Response model for user-specific suggested prompts."""
|
||||||
|
|
||||||
|
prompts: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/suggested-prompts",
|
||||||
|
dependencies=[Security(auth.requires_user)],
|
||||||
|
)
|
||||||
|
async def get_suggested_prompts(
|
||||||
|
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||||
|
) -> SuggestedPromptsResponse:
|
||||||
|
"""
|
||||||
|
Get LLM-generated suggested prompts for the authenticated user.
|
||||||
|
|
||||||
|
Returns personalized quick-action prompts based on the user's
|
||||||
|
business understanding. Returns an empty list if no custom prompts
|
||||||
|
are available.
|
||||||
|
"""
|
||||||
|
understanding = await get_business_understanding(user_id)
|
||||||
|
if understanding is None:
|
||||||
|
return SuggestedPromptsResponse(prompts=[])
|
||||||
|
|
||||||
|
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
# ========== Configuration ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import fastapi.testclient
|
import fastapi.testclient
|
||||||
@@ -249,3 +249,62 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
|||||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||||
assert call_kwargs["where"]["isDeleted"] is False
|
assert call_kwargs["where"]["isDeleted"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_get_business_understanding(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
*,
|
||||||
|
return_value=None,
|
||||||
|
):
|
||||||
|
"""Mock get_business_understanding."""
|
||||||
|
return mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_business_understanding",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=return_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_returns_prompts(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with understanding and prompts gets them back."""
|
||||||
|
mock_understanding = MagicMock()
|
||||||
|
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||||
|
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_no_understanding(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with no understanding gets empty list."""
|
||||||
|
_mock_get_business_understanding(mocker, return_value=None)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_empty_prompts(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with understanding but no prompts gets empty list."""
|
||||||
|
mock_understanding = MagicMock()
|
||||||
|
mock_understanding.suggested_prompts = []
|
||||||
|
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": []}
|
||||||
|
|||||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
|||||||
|
|
||||||
# Mock get_node_executions to return node_id mapping
|
# Mock get_node_executions to return node_id mapping
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||||
mock_node_exec.node_exec_id = "test_node_123"
|
mock_node_exec.node_exec_id = "test_node_123"
|
||||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
|||||||
|
|
||||||
# Mock get_node_executions to return node_id mapping
|
# Mock get_node_executions to return node_id mapping
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
|||||||
|
|
||||||
# Mock get_node_executions to return batch node data
|
# Mock get_node_executions to return batch node data
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
# Create mock node executions for each review
|
# Create mock node executions for each review
|
||||||
mock_node_execs = []
|
mock_node_execs = []
|
||||||
|
|||||||
@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
|
|||||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
|
from backend.copilot.constants import (
|
||||||
|
is_copilot_synthetic_id,
|
||||||
|
parse_node_id_from_exec_id,
|
||||||
|
)
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
ExecutionContext,
|
ExecutionContext,
|
||||||
ExecutionStatus,
|
ExecutionStatus,
|
||||||
get_graph_execution_meta,
|
get_graph_execution_meta,
|
||||||
|
get_node_executions,
|
||||||
)
|
)
|
||||||
from backend.data.graph import get_graph_settings
|
from backend.data.graph import get_graph_settings
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
@@ -36,6 +41,38 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_node_ids(
|
||||||
|
node_exec_ids: list[str],
|
||||||
|
graph_exec_id: str,
|
||||||
|
is_copilot: bool,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||||
|
|
||||||
|
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||||
|
Graph executions look up node_id from NodeExecution records.
|
||||||
|
"""
|
||||||
|
if not node_exec_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if is_copilot:
|
||||||
|
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||||
|
|
||||||
|
node_execs = await get_node_executions(
|
||||||
|
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||||
|
)
|
||||||
|
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for neid in node_exec_ids:
|
||||||
|
if neid in node_exec_map:
|
||||||
|
result[neid] = node_exec_map[neid]
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/pending",
|
"/pending",
|
||||||
summary="Get Pending Reviews",
|
summary="Get Pending Reviews",
|
||||||
@@ -110,14 +147,16 @@ async def list_pending_reviews_for_execution(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Verify user owns the graph execution before returning reviews
|
# Verify user owns the graph execution before returning reviews
|
||||||
graph_exec = await get_graph_execution_meta(
|
# (CoPilot synthetic IDs don't have graph execution records)
|
||||||
user_id=user_id, execution_id=graph_exec_id
|
if not is_copilot_synthetic_id(graph_exec_id):
|
||||||
)
|
graph_exec = await get_graph_execution_meta(
|
||||||
if not graph_exec:
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Graph execution #{graph_exec_id} not found",
|
|
||||||
)
|
)
|
||||||
|
if not graph_exec:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||||
|
|
||||||
@@ -160,30 +199,26 @@ async def process_review_action(
|
|||||||
)
|
)
|
||||||
|
|
||||||
graph_exec_id = next(iter(graph_exec_ids))
|
graph_exec_id = next(iter(graph_exec_ids))
|
||||||
|
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||||
|
|
||||||
# Validate execution status before processing reviews
|
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||||
graph_exec_meta = await get_graph_execution_meta(
|
if not is_copilot:
|
||||||
user_id=user_id, execution_id=graph_exec_id
|
graph_exec_meta = await get_graph_execution_meta(
|
||||||
)
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
|
||||||
if not graph_exec_meta:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Graph execution #{graph_exec_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only allow processing reviews if execution is paused for review
|
|
||||||
# or incomplete (partial execution with some reviews already processed)
|
|
||||||
if graph_exec_meta.status not in (
|
|
||||||
ExecutionStatus.REVIEW,
|
|
||||||
ExecutionStatus.INCOMPLETE,
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
|
||||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
|
||||||
f"Current status: {graph_exec_meta.status}",
|
|
||||||
)
|
)
|
||||||
|
if not graph_exec_meta:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
if graph_exec_meta.status not in (
|
||||||
|
ExecutionStatus.REVIEW,
|
||||||
|
ExecutionStatus.INCOMPLETE,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||||
|
)
|
||||||
|
|
||||||
# Build review decisions map and track which reviews requested auto-approval
|
# Build review decisions map and track which reviews requested auto-approval
|
||||||
# Auto-approved reviews use original data (no modifications allowed)
|
# Auto-approved reviews use original data (no modifications allowed)
|
||||||
@@ -236,7 +271,7 @@ async def process_review_action(
|
|||||||
)
|
)
|
||||||
return (node_id, False)
|
return (node_id, False)
|
||||||
|
|
||||||
# Collect node_exec_ids that need auto-approval
|
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||||
node_exec_ids_needing_auto_approval = [
|
node_exec_ids_needing_auto_approval = [
|
||||||
node_exec_id
|
node_exec_id
|
||||||
for node_exec_id, review_result in updated_reviews.items()
|
for node_exec_id, review_result in updated_reviews.items()
|
||||||
@@ -244,29 +279,16 @@ async def process_review_action(
|
|||||||
and auto_approve_requests.get(node_exec_id, False)
|
and auto_approve_requests.get(node_exec_id, False)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Batch-fetch node executions to get node_ids
|
node_id_map = await _resolve_node_ids(
|
||||||
|
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deduplicate by node_id — one auto-approval per node
|
||||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||||
if node_exec_ids_needing_auto_approval:
|
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||||
from backend.data.execution import get_node_executions
|
node_id = node_id_map.get(node_exec_id)
|
||||||
|
if node_id and node_id not in nodes_needing_auto_approval:
|
||||||
node_execs = await get_node_executions(
|
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
|
||||||
)
|
|
||||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
|
||||||
|
|
||||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
|
||||||
node_exec = node_exec_map.get(node_exec_id)
|
|
||||||
if node_exec:
|
|
||||||
review_result = updated_reviews[node_exec_id]
|
|
||||||
# Use the first approved review for this node (deduplicate by node_id)
|
|
||||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
|
||||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
|
||||||
f"Node execution not found. This may indicate a race condition "
|
|
||||||
f"or data inconsistency."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||||
auto_approval_results = await asyncio.gather(
|
auto_approval_results = await asyncio.gather(
|
||||||
@@ -281,13 +303,11 @@ async def process_review_action(
|
|||||||
auto_approval_failed_count = 0
|
auto_approval_failed_count = 0
|
||||||
for result in auto_approval_results:
|
for result in auto_approval_results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
# Unexpected exception during auto-approval creation
|
|
||||||
auto_approval_failed_count += 1
|
auto_approval_failed_count += 1
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unexpected exception during auto-approval creation: {result}"
|
f"Unexpected exception during auto-approval creation: {result}"
|
||||||
)
|
)
|
||||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||||
# Auto-approval creation failed (returned False)
|
|
||||||
auto_approval_failed_count += 1
|
auto_approval_failed_count += 1
|
||||||
|
|
||||||
# Count results
|
# Count results
|
||||||
@@ -302,22 +322,20 @@ async def process_review_action(
|
|||||||
if review.status == ReviewStatus.REJECTED
|
if review.status == ReviewStatus.REJECTED
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
# Resume graph execution only for real graph executions (not CoPilot)
|
||||||
if updated_reviews:
|
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||||
|
if not is_copilot and updated_reviews:
|
||||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||||
|
|
||||||
if not still_has_pending:
|
if not still_has_pending:
|
||||||
# Get the graph_id from any processed review
|
|
||||||
first_review = next(iter(updated_reviews.values()))
|
first_review = next(iter(updated_reviews.values()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch user and settings to build complete execution context
|
|
||||||
user = await get_user_by_id(user_id)
|
user = await get_user_by_id(user_id)
|
||||||
settings = await get_graph_settings(
|
settings = await get_graph_settings(
|
||||||
user_id=user_id, graph_id=first_review.graph_id
|
user_id=user_id, graph_id=first_review.graph_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Preserve user's timezone preference when resuming execution
|
|
||||||
user_timezone = (
|
user_timezone = (
|
||||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -165,7 +165,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
owner_user_id: str
|
|
||||||
|
|
||||||
image_url: str | None
|
image_url: str | None
|
||||||
|
|
||||||
@@ -206,7 +205,9 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="List of recent executions with status, score, and summary",
|
description="List of recent executions with status, score, and summary",
|
||||||
)
|
)
|
||||||
can_access_graph: bool
|
can_access_graph: bool = pydantic.Field(
|
||||||
|
description="Indicates whether the same user owns the corresponding graph"
|
||||||
|
)
|
||||||
is_latest_version: bool
|
is_latest_version: bool
|
||||||
is_favorite: bool
|
is_favorite: bool
|
||||||
folder_id: str | None = None
|
folder_id: str | None = None
|
||||||
@@ -324,7 +325,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id=agent.id,
|
id=agent.id,
|
||||||
graph_id=agent.agentGraphId,
|
graph_id=agent.agentGraphId,
|
||||||
graph_version=agent.agentGraphVersion,
|
graph_version=agent.agentGraphVersion,
|
||||||
owner_user_id=agent.userId,
|
|
||||||
image_url=agent.imageUrl,
|
image_url=agent.imageUrl,
|
||||||
creator_name=creator_name,
|
creator_name=creator_name,
|
||||||
creator_image_url=creator_image_url,
|
creator_image_url=creator_image_url,
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -67,7 +66,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-2",
|
id="test-agent-2",
|
||||||
graph_id="test-agent-2",
|
graph_id="test-agent-2",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 2",
|
name="Test Agent 2",
|
||||||
description="Test Description 2",
|
description="Test Description 2",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -131,7 +129,6 @@ async def test_get_favorite_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Favorite Agent 1",
|
name="Favorite Agent 1",
|
||||||
description="Test Favorite Description 1",
|
description="Test Favorite Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -184,7 +181,6 @@ def test_add_agent_to_library_success(
|
|||||||
id="test-library-agent-id",
|
id="test-library-agent-id",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|||||||
from backend.data.model import OAuth2Credentials
|
from backend.data.model import OAuth2Credentials
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -80,7 +80,7 @@ async def discover_tools(
|
|||||||
"""
|
"""
|
||||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
try:
|
try:
|
||||||
await validate_url(request.server_url, trusted_origins=[])
|
await validate_url_host(request.server_url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
@@ -167,7 +167,7 @@ async def mcp_oauth_login(
|
|||||||
"""
|
"""
|
||||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
try:
|
try:
|
||||||
await validate_url(request.server_url, trusted_origins=[])
|
await validate_url_host(request.server_url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ async def mcp_oauth_login(
|
|||||||
|
|
||||||
# Validate the auth server URL from metadata to prevent SSRF.
|
# Validate the auth server URL from metadata to prevent SSRF.
|
||||||
try:
|
try:
|
||||||
await validate_url(auth_server_url, trusted_origins=[])
|
await validate_url_host(auth_server_url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise fastapi.HTTPException(
|
raise fastapi.HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
@@ -234,7 +234,7 @@ async def mcp_oauth_login(
|
|||||||
if registration_endpoint:
|
if registration_endpoint:
|
||||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||||
try:
|
try:
|
||||||
await validate_url(registration_endpoint, trusted_origins=[])
|
await validate_url_host(registration_endpoint)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass # Skip registration, fall back to default client_id
|
pass # Skip registration, fall back to default client_id
|
||||||
else:
|
else:
|
||||||
@@ -429,7 +429,7 @@ async def mcp_store_token(
|
|||||||
|
|
||||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
try:
|
try:
|
||||||
await validate_url(request.server_url, trusted_origins=[])
|
await validate_url_host(request.server_url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,9 @@ async def client():
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _bypass_ssrf_validation():
|
def _bypass_ssrf_validation():
|
||||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.mcp.routes.validate_url",
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
@@ -521,12 +521,12 @@ class TestStoreToken:
|
|||||||
|
|
||||||
|
|
||||||
class TestSSRFValidation:
|
class TestSSRFValidation:
|
||||||
"""Verify that validate_url is enforced on all endpoints."""
|
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_discover_tools_ssrf_blocked(self, client):
|
async def test_discover_tools_ssrf_blocked(self, client):
|
||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.mcp.routes.validate_url",
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=ValueError("blocked loopback"),
|
side_effect=ValueError("blocked loopback"),
|
||||||
):
|
):
|
||||||
@@ -541,7 +541,7 @@ class TestSSRFValidation:
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_oauth_login_ssrf_blocked(self, client):
|
async def test_oauth_login_ssrf_blocked(self, client):
|
||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.mcp.routes.validate_url",
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=ValueError("blocked private IP"),
|
side_effect=ValueError("blocked private IP"),
|
||||||
):
|
):
|
||||||
@@ -556,7 +556,7 @@ class TestSSRFValidation:
|
|||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_store_token_ssrf_blocked(self, client):
|
async def test_store_token_ssrf_blocked(self, client):
|
||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.mcp.routes.validate_url",
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=ValueError("blocked loopback"),
|
side_effect=ValueError("blocked loopback"),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from backend.data.credit import (
|
|||||||
set_auto_top_up,
|
set_auto_top_up,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphSettings
|
from backend.data.graph import GraphSettings
|
||||||
|
from backend.data.invited_user import get_or_activate_user
|
||||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||||
from backend.data.onboarding import (
|
from backend.data.onboarding import (
|
||||||
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
|
|||||||
update_user_onboarding,
|
update_user_onboarding,
|
||||||
)
|
)
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_or_create_user,
|
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
get_user_notification_preference,
|
get_user_notification_preference,
|
||||||
update_user_email,
|
update_user_email,
|
||||||
@@ -136,12 +136,10 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||||
user = await get_or_create_user(user_data)
|
user = await get_or_activate_user(user_data)
|
||||||
|
|
||||||
# Fire-and-forget: populate business understanding from Tally form.
|
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||||
# We use created_at proximity instead of an is_new flag because
|
# not produce a stored result before first activation.
|
||||||
# get_or_create_user is cached — a separate is_new return value would be
|
|
||||||
# unreliable on repeated calls within the cache TTL.
|
|
||||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||||
if age_seconds < 30:
|
if age_seconds < 30:
|
||||||
try:
|
try:
|
||||||
@@ -165,7 +163,8 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def update_user_email_route(
|
async def update_user_email_route(
|
||||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
email: str = Body(...),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
await update_user_email(user_id, email)
|
await update_user_email(user_id, email)
|
||||||
|
|
||||||
@@ -179,10 +178,16 @@ async def update_user_email_route(
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def get_user_timezone_route(
|
async def get_user_timezone_route(
|
||||||
user_data: dict = Security(get_jwt_payload),
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> TimezoneResponse:
|
) -> TimezoneResponse:
|
||||||
"""Get user timezone setting."""
|
"""Get user timezone setting."""
|
||||||
user = await get_or_create_user(user_data)
|
try:
|
||||||
|
user = await get_user_by_id(user_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found. Please complete activation via /auth/user first.",
|
||||||
|
)
|
||||||
return TimezoneResponse(timezone=user.timezone)
|
return TimezoneResponse(timezone=user.timezone)
|
||||||
|
|
||||||
|
|
||||||
@@ -193,7 +198,8 @@ async def get_user_timezone_route(
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def update_user_timezone_route(
|
async def update_user_timezone_route(
|
||||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
request: UpdateTimezoneRequest,
|
||||||
) -> TimezoneResponse:
|
) -> TimezoneResponse:
|
||||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||||
user = await update_user_timezone(user_id, str(request.timezone))
|
user = await update_user_timezone(user_id, str(request.timezone))
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def test_get_or_create_user_route(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.api.features.v1.get_or_create_user",
|
"backend.api.features.v1.get_or_activate_user",
|
||||||
return_value=mock_user,
|
return_value=mock_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -94,3 +94,8 @@ class NotificationPayload(pydantic.BaseModel):
|
|||||||
|
|
||||||
class OnboardingNotificationPayload(NotificationPayload):
|
class OnboardingNotificationPayload(NotificationPayload):
|
||||||
step: OnboardingStep | None
|
step: OnboardingStep | None
|
||||||
|
|
||||||
|
|
||||||
|
class CopilotCompletionPayload(NotificationPayload):
|
||||||
|
session_id: str
|
||||||
|
status: Literal["completed", "failed"]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
|
|||||||
import backend.api.features.admin.credit_admin_routes
|
import backend.api.features.admin.credit_admin_routes
|
||||||
import backend.api.features.admin.execution_analytics_routes
|
import backend.api.features.admin.execution_analytics_routes
|
||||||
import backend.api.features.admin.store_admin_routes
|
import backend.api.features.admin.store_admin_routes
|
||||||
|
import backend.api.features.admin.user_admin_routes
|
||||||
import backend.api.features.builder
|
import backend.api.features.builder
|
||||||
import backend.api.features.builder.routes
|
import backend.api.features.builder.routes
|
||||||
import backend.api.features.chat.routes as chat_routes
|
import backend.api.features.chat.routes as chat_routes
|
||||||
@@ -311,6 +312,11 @@ app.include_router(
|
|||||||
tags=["v2", "admin"],
|
tags=["v2", "admin"],
|
||||||
prefix="/api/executions",
|
prefix="/api/executions",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
backend.api.features.admin.user_admin_routes.router,
|
||||||
|
tags=["v2", "admin"],
|
||||||
|
prefix="/api/users",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.executions.review.routes.router,
|
backend.api.features.executions.review.routes.router,
|
||||||
tags=["v2", "executions", "review"],
|
tags=["v2", "executions", "review"],
|
||||||
|
|||||||
@@ -418,6 +418,8 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||||
|
_optimized_description: ClassVar[str | None] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str = "",
|
id: str = "",
|
||||||
@@ -470,6 +472,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
self.block_type = block_type
|
self.block_type = block_type
|
||||||
self.webhook_config = webhook_config
|
self.webhook_config = webhook_config
|
||||||
self.is_sensitive_action = is_sensitive_action
|
self.is_sensitive_action = is_sensitive_action
|
||||||
|
# Read from ClassVar set by initialize_blocks()
|
||||||
|
self.optimized_description: str | None = type(self)._optimized_description
|
||||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||||
|
|
||||||
if self.webhook_config:
|
if self.webhook_config:
|
||||||
@@ -620,6 +624,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
execution_context: "ExecutionContext",
|
execution_context: "ExecutionContext",
|
||||||
|
is_graph_execution: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, BlockInput]:
|
) -> tuple[bool, BlockInput]:
|
||||||
"""
|
"""
|
||||||
@@ -648,6 +653,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
block_name=self.name,
|
block_name=self.name,
|
||||||
editable=True,
|
editable=True,
|
||||||
|
is_graph_execution=is_graph_execution,
|
||||||
)
|
)
|
||||||
|
|
||||||
if decision is None:
|
if decision is None:
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
|
|||||||
output_schema=PrintToConsoleBlock.Output,
|
output_schema=PrintToConsoleBlock.Output,
|
||||||
test_input={"text": "Hello, World!"},
|
test_input={"text": "Hello, World!"},
|
||||||
is_sensitive_action=True,
|
is_sensitive_action=True,
|
||||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
disabled=True,
|
||||||
test_output=[
|
test_output=[
|
||||||
("output", "Hello, World!"),
|
("output", "Hello, World!"),
|
||||||
("status", "printed"),
|
("status", "printed"),
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
|
|||||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||||
|
|
||||||
# Execute the code
|
# Execute the code
|
||||||
execution = await sandbox.run_code(
|
execution = await sandbox.run_code( # type: ignore[attr-defined]
|
||||||
code,
|
code,
|
||||||
language=language.value,
|
language=language.value,
|
||||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ class SendEmailBlock(Block):
|
|||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("status", "Email sent successfully")],
|
test_output=[("status", "Email sent successfully")],
|
||||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||||
|
is_sensitive_action=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
3
autogpt_platform/backend/backend/blocks/github/_utils.py
Normal file
3
autogpt_platform/backend/backend/blocks/github/_utils.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
def github_repo_path(repo_url: str) -> str:
|
||||||
|
"""Extract 'owner/repo' from a GitHub repository URL."""
|
||||||
|
return repo_url.replace("https://github.com/", "")
|
||||||
408
autogpt_platform/backend/backend/blocks/github/commits.py
Normal file
408
autogpt_platform/backend/backend/blocks/github/commits.py
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import StrEnum
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import parse_data_uri, resolve_media_content
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
from ._api import get_api
|
||||||
|
from ._auth import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
GithubCredentials,
|
||||||
|
GithubCredentialsField,
|
||||||
|
GithubCredentialsInput,
|
||||||
|
)
|
||||||
|
from ._utils import github_repo_path
|
||||||
|
|
||||||
|
|
||||||
|
class GithubListCommitsBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch name to list commits from",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
per_page: int = SchemaField(
|
||||||
|
description="Number of commits to return (max 100)",
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
)
|
||||||
|
page: int = SchemaField(
|
||||||
|
description="Page number for pagination",
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class CommitItem(TypedDict):
|
||||||
|
sha: str
|
||||||
|
message: str
|
||||||
|
author: str
|
||||||
|
date: str
|
||||||
|
url: str
|
||||||
|
|
||||||
|
commit: CommitItem = SchemaField(
|
||||||
|
title="Commit", description="A commit with its details"
|
||||||
|
)
|
||||||
|
commits: list[CommitItem] = SchemaField(
|
||||||
|
description="List of commits with their details"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if listing commits failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8b13f579-d8b6-4dc2-a140-f770428805de",
|
||||||
|
description="This block lists commits on a branch in a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubListCommitsBlock.Input,
|
||||||
|
output_schema=GithubListCommitsBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"branch": "main",
|
||||||
|
"per_page": 30,
|
||||||
|
"page": 1,
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"commits",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sha": "abc123",
|
||||||
|
"message": "Initial commit",
|
||||||
|
"author": "octocat",
|
||||||
|
"date": "2024-01-01T00:00:00Z",
|
||||||
|
"url": "https://github.com/owner/repo/commit/abc123",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"commit",
|
||||||
|
{
|
||||||
|
"sha": "abc123",
|
||||||
|
"message": "Initial commit",
|
||||||
|
"author": "octocat",
|
||||||
|
"date": "2024-01-01T00:00:00Z",
|
||||||
|
"url": "https://github.com/owner/repo/commit/abc123",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"list_commits": lambda *args, **kwargs: [
|
||||||
|
{
|
||||||
|
"sha": "abc123",
|
||||||
|
"message": "Initial commit",
|
||||||
|
"author": "octocat",
|
||||||
|
"date": "2024-01-01T00:00:00Z",
|
||||||
|
"url": "https://github.com/owner/repo/commit/abc123",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def list_commits(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
branch: str,
|
||||||
|
per_page: int,
|
||||||
|
page: int,
|
||||||
|
) -> list[Output.CommitItem]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
commits_url = repo_url + "/commits"
|
||||||
|
params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
|
||||||
|
response = await api.get(commits_url, params=params)
|
||||||
|
data = response.json()
|
||||||
|
repo_path = github_repo_path(repo_url)
|
||||||
|
return [
|
||||||
|
GithubListCommitsBlock.Output.CommitItem(
|
||||||
|
sha=c["sha"],
|
||||||
|
message=c["commit"]["message"],
|
||||||
|
author=(c["commit"].get("author") or {}).get("name", "Unknown"),
|
||||||
|
date=(c["commit"].get("author") or {}).get("date", ""),
|
||||||
|
url=f"https://github.com/{repo_path}/commit/{c['sha']}",
|
||||||
|
)
|
||||||
|
for c in data
|
||||||
|
]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
commits = await self.list_commits(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.branch,
|
||||||
|
input_data.per_page,
|
||||||
|
input_data.page,
|
||||||
|
)
|
||||||
|
yield "commits", commits
|
||||||
|
for commit in commits:
|
||||||
|
yield "commit", commit
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperation(StrEnum):
|
||||||
|
"""File operations for GithubMultiFileCommitBlock.
|
||||||
|
|
||||||
|
UPSERT creates or overwrites a file (the Git Trees API does not distinguish
|
||||||
|
between creation and update — the blob is placed at the given path regardless
|
||||||
|
of whether a file already exists there).
|
||||||
|
|
||||||
|
DELETE removes a file from the tree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
UPSERT = "upsert"
|
||||||
|
DELETE = "delete"
|
||||||
|
|
||||||
|
|
||||||
|
class FileOperationInput(TypedDict):
|
||||||
|
path: str
|
||||||
|
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||||
|
content: MediaFileType
|
||||||
|
operation: FileOperation
|
||||||
|
|
||||||
|
|
||||||
|
class GithubMultiFileCommitBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch to commit to",
|
||||||
|
placeholder="feature-branch",
|
||||||
|
)
|
||||||
|
commit_message: str = SchemaField(
|
||||||
|
description="Commit message",
|
||||||
|
placeholder="Add new feature",
|
||||||
|
)
|
||||||
|
files: list[FileOperationInput] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"List of file operations. Each item has: "
|
||||||
|
"'path' (file path), 'content' (file content, ignored for delete), "
|
||||||
|
"'operation' (upsert/delete)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
sha: str = SchemaField(description="SHA of the new commit")
|
||||||
|
url: str = SchemaField(description="URL of the new commit")
|
||||||
|
error: str = SchemaField(description="Error message if the commit failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="389eee51-a95e-4230-9bed-92167a327802",
|
||||||
|
description=(
|
||||||
|
"This block creates a single commit with multiple file "
|
||||||
|
"upsert/delete operations using the Git Trees API."
|
||||||
|
),
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubMultiFileCommitBlock.Input,
|
||||||
|
output_schema=GithubMultiFileCommitBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"branch": "feature",
|
||||||
|
"commit_message": "Add files",
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"path": "src/new.py",
|
||||||
|
"content": "print('hello')",
|
||||||
|
"operation": "upsert",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "src/old.py",
|
||||||
|
"content": "",
|
||||||
|
"operation": "delete",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("sha", "newcommitsha"),
|
||||||
|
("url", "https://github.com/owner/repo/commit/newcommitsha"),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"multi_file_commit": lambda *args, **kwargs: (
|
||||||
|
"newcommitsha",
|
||||||
|
"https://github.com/owner/repo/commit/newcommitsha",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def multi_file_commit(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
branch: str,
|
||||||
|
commit_message: str,
|
||||||
|
files: list[FileOperationInput],
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
safe_branch = quote(branch, safe="")
|
||||||
|
|
||||||
|
# 1. Get the latest commit SHA for the branch
|
||||||
|
ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
|
||||||
|
response = await api.get(ref_url)
|
||||||
|
ref_data = response.json()
|
||||||
|
latest_commit_sha = ref_data["object"]["sha"]
|
||||||
|
|
||||||
|
# 2. Get the tree SHA of the latest commit
|
||||||
|
commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
|
||||||
|
response = await api.get(commit_url)
|
||||||
|
commit_data = response.json()
|
||||||
|
base_tree_sha = commit_data["tree"]["sha"]
|
||||||
|
|
||||||
|
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||||
|
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||||
|
blob_url = repo_url + "/git/blobs"
|
||||||
|
blob_response = await api.post(
|
||||||
|
blob_url,
|
||||||
|
json={"content": content, "encoding": encoding},
|
||||||
|
)
|
||||||
|
return blob_response.json()["sha"]
|
||||||
|
|
||||||
|
tree_entries: list[dict] = []
|
||||||
|
upsert_files = []
|
||||||
|
for file_op in files:
|
||||||
|
path = file_op["path"]
|
||||||
|
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||||
|
|
||||||
|
if operation == FileOperation.DELETE:
|
||||||
|
tree_entries.append(
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"mode": "100644",
|
||||||
|
"type": "blob",
|
||||||
|
"sha": None, # null SHA = delete
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
upsert_files.append((path, file_op.get("content", "")))
|
||||||
|
|
||||||
|
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||||
|
# are sent as base64 blobs to preserve binary content.
|
||||||
|
if upsert_files:
|
||||||
|
|
||||||
|
async def _make_blob(content: str) -> str:
|
||||||
|
parsed = parse_data_uri(content)
|
||||||
|
if parsed is not None:
|
||||||
|
_, b64_payload = parsed
|
||||||
|
return await _create_blob(b64_payload, encoding="base64")
|
||||||
|
return await _create_blob(content)
|
||||||
|
|
||||||
|
blob_shas = await asyncio.gather(
|
||||||
|
*[_make_blob(content) for _, content in upsert_files]
|
||||||
|
)
|
||||||
|
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||||
|
tree_entries.append(
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"mode": "100644",
|
||||||
|
"type": "blob",
|
||||||
|
"sha": blob_sha,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Create a new tree
|
||||||
|
tree_url = repo_url + "/git/trees"
|
||||||
|
tree_response = await api.post(
|
||||||
|
tree_url,
|
||||||
|
json={"base_tree": base_tree_sha, "tree": tree_entries},
|
||||||
|
)
|
||||||
|
new_tree_sha = tree_response.json()["sha"]
|
||||||
|
|
||||||
|
# 5. Create a new commit
|
||||||
|
new_commit_url = repo_url + "/git/commits"
|
||||||
|
commit_response = await api.post(
|
||||||
|
new_commit_url,
|
||||||
|
json={
|
||||||
|
"message": commit_message,
|
||||||
|
"tree": new_tree_sha,
|
||||||
|
"parents": [latest_commit_sha],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
new_commit_sha = commit_response.json()["sha"]
|
||||||
|
|
||||||
|
# 6. Update the branch reference
|
||||||
|
try:
|
||||||
|
await api.patch(
|
||||||
|
ref_url,
|
||||||
|
json={"sha": new_commit_sha},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Commit {new_commit_sha} was created but failed to update "
|
||||||
|
f"ref heads/{branch}: {e}. "
|
||||||
|
f"You can recover by manually updating the branch to {new_commit_sha}."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
repo_path = github_repo_path(repo_url)
|
||||||
|
commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
|
||||||
|
return new_commit_sha, commit_web_url
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
# Resolve media references (workspace://, data:, URLs) to data
|
||||||
|
# URIs so _make_blob can send binary content correctly.
|
||||||
|
resolved_files: list[FileOperationInput] = []
|
||||||
|
for file_op in input_data.files:
|
||||||
|
content = file_op.get("content", "")
|
||||||
|
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||||
|
if operation != FileOperation.DELETE:
|
||||||
|
content = await resolve_media_content(
|
||||||
|
MediaFileType(content),
|
||||||
|
execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
resolved_files.append(
|
||||||
|
FileOperationInput(
|
||||||
|
path=file_op["path"],
|
||||||
|
content=MediaFileType(content),
|
||||||
|
operation=operation,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sha, url = await self.multi_file_commit(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.branch,
|
||||||
|
input_data.commit_message,
|
||||||
|
resolved_files,
|
||||||
|
)
|
||||||
|
yield "sha", sha
|
||||||
|
yield "url", url
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
@@ -20,6 +21,8 @@ from ._auth import (
|
|||||||
GithubCredentialsInput,
|
GithubCredentialsInput,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MergeMethod = Literal["merge", "squash", "rebase"]
|
||||||
|
|
||||||
|
|
||||||
class GithubListPullRequestsBlock(Block):
|
class GithubListPullRequestsBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
@@ -558,12 +561,109 @@ class GithubListPRReviewersBlock(Block):
|
|||||||
yield "reviewer", reviewer
|
yield "reviewer", reviewer
|
||||||
|
|
||||||
|
|
||||||
|
class GithubMergePullRequestBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
pr_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub pull request",
|
||||||
|
placeholder="https://github.com/owner/repo/pull/1",
|
||||||
|
)
|
||||||
|
merge_method: MergeMethod = SchemaField(
|
||||||
|
description="Merge method to use: merge, squash, or rebase",
|
||||||
|
default="merge",
|
||||||
|
)
|
||||||
|
commit_title: str = SchemaField(
|
||||||
|
description="Title for the merge commit (optional, used for merge and squash)",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
commit_message: str = SchemaField(
|
||||||
|
description="Message for the merge commit (optional, used for merge and squash)",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
sha: str = SchemaField(description="SHA of the merge commit")
|
||||||
|
merged: bool = SchemaField(description="Whether the PR was merged")
|
||||||
|
message: str = SchemaField(description="Merge status message")
|
||||||
|
error: str = SchemaField(description="Error message if the merge failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
|
||||||
|
description="This block merges a pull request using merge, squash, or rebase.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubMergePullRequestBlock.Input,
|
||||||
|
output_schema=GithubMergePullRequestBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||||
|
"merge_method": "squash",
|
||||||
|
"commit_title": "",
|
||||||
|
"commit_message": "",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("sha", "abc123"),
|
||||||
|
("merged", True),
|
||||||
|
("message", "Pull Request successfully merged"),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"merge_pr": lambda *args, **kwargs: (
|
||||||
|
"abc123",
|
||||||
|
True,
|
||||||
|
"Pull Request successfully merged",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
is_sensitive_action=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def merge_pr(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
pr_url: str,
|
||||||
|
merge_method: MergeMethod,
|
||||||
|
commit_title: str,
|
||||||
|
commit_message: str,
|
||||||
|
) -> tuple[str, bool, str]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
||||||
|
data: dict[str, str] = {"merge_method": merge_method}
|
||||||
|
if commit_title:
|
||||||
|
data["commit_title"] = commit_title
|
||||||
|
if commit_message:
|
||||||
|
data["commit_message"] = commit_message
|
||||||
|
response = await api.put(merge_url, json=data)
|
||||||
|
result = response.json()
|
||||||
|
return result["sha"], result["merged"], result["message"]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
sha, merged, message = await self.merge_pr(
|
||||||
|
credentials,
|
||||||
|
input_data.pr_url,
|
||||||
|
input_data.merge_method,
|
||||||
|
input_data.commit_title,
|
||||||
|
input_data.commit_message,
|
||||||
|
)
|
||||||
|
yield "sha", sha
|
||||||
|
yield "merged", merged
|
||||||
|
yield "message", message
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||||
# Pattern to capture the base repository URL and the pull request number
|
# Pattern to capture the base repository URL and the pull request number
|
||||||
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||||
match = re.match(pattern, pr_url)
|
match = re.match(pattern, pr_url)
|
||||||
if not match:
|
if not match:
|
||||||
return pr_url
|
return pr_url
|
||||||
|
|
||||||
base_url, pr_number = match.groups()
|
scheme, base_url, pr_number = match.groups()
|
||||||
return f"{base_url}/pulls/{pr_number}/{path}"
|
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import base64
|
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks._base import (
|
||||||
@@ -19,6 +17,7 @@ from ._auth import (
|
|||||||
GithubCredentialsField,
|
GithubCredentialsField,
|
||||||
GithubCredentialsInput,
|
GithubCredentialsInput,
|
||||||
)
|
)
|
||||||
|
from ._utils import github_repo_path
|
||||||
|
|
||||||
|
|
||||||
class GithubListTagsBlock(Block):
|
class GithubListTagsBlock(Block):
|
||||||
@@ -89,7 +88,7 @@ class GithubListTagsBlock(Block):
|
|||||||
tags_url = repo_url + "/tags"
|
tags_url = repo_url + "/tags"
|
||||||
response = await api.get(tags_url)
|
response = await api.get(tags_url)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
repo_path = repo_url.replace("https://github.com/", "")
|
repo_path = github_repo_path(repo_url)
|
||||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||||
{
|
{
|
||||||
"name": tag["name"],
|
"name": tag["name"],
|
||||||
@@ -115,101 +114,6 @@ class GithubListTagsBlock(Block):
|
|||||||
yield "tag", tag
|
yield "tag", tag
|
||||||
|
|
||||||
|
|
||||||
class GithubListBranchesBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
class BranchItem(TypedDict):
|
|
||||||
name: str
|
|
||||||
url: str
|
|
||||||
|
|
||||||
branch: BranchItem = SchemaField(
|
|
||||||
title="Branch",
|
|
||||||
description="Branches with their name and file tree browser URL",
|
|
||||||
)
|
|
||||||
branches: list[BranchItem] = SchemaField(
|
|
||||||
description="List of branches with their name and file tree browser URL"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
|
||||||
description="This block lists all branches for a specified GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubListBranchesBlock.Input,
|
|
||||||
output_schema=GithubListBranchesBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"branches",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "main",
|
|
||||||
"url": "https://github.com/owner/repo/tree/main",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"branch",
|
|
||||||
{
|
|
||||||
"name": "main",
|
|
||||||
"url": "https://github.com/owner/repo/tree/main",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"list_branches": lambda *args, **kwargs: [
|
|
||||||
{
|
|
||||||
"name": "main",
|
|
||||||
"url": "https://github.com/owner/repo/tree/main",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def list_branches(
|
|
||||||
credentials: GithubCredentials, repo_url: str
|
|
||||||
) -> list[Output.BranchItem]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
branches_url = repo_url + "/branches"
|
|
||||||
response = await api.get(branches_url)
|
|
||||||
data = response.json()
|
|
||||||
repo_path = repo_url.replace("https://github.com/", "")
|
|
||||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
|
||||||
{
|
|
||||||
"name": branch["name"],
|
|
||||||
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
|
||||||
}
|
|
||||||
for branch in data
|
|
||||||
]
|
|
||||||
return branches
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
branches = await self.list_branches(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
)
|
|
||||||
yield "branches", branches
|
|
||||||
for branch in branches:
|
|
||||||
yield "branch", branch
|
|
||||||
|
|
||||||
|
|
||||||
class GithubListDiscussionsBlock(Block):
|
class GithubListDiscussionsBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
@@ -283,7 +187,7 @@ class GithubListDiscussionsBlock(Block):
|
|||||||
) -> list[Output.DiscussionItem]:
|
) -> list[Output.DiscussionItem]:
|
||||||
api = get_api(credentials)
|
api = get_api(credentials)
|
||||||
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
# GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
|
||||||
repo_path = repo_url.replace("https://github.com/", "")
|
repo_path = github_repo_path(repo_url)
|
||||||
owner, repo = repo_path.split("/")
|
owner, repo = repo_path.split("/")
|
||||||
query = """
|
query = """
|
||||||
query($owner: String!, $repo: String!, $num: Int!) {
|
query($owner: String!, $repo: String!, $num: Int!) {
|
||||||
@@ -416,564 +320,6 @@ class GithubListReleasesBlock(Block):
|
|||||||
yield "release", release
|
yield "release", release
|
||||||
|
|
||||||
|
|
||||||
class GithubReadFileBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
file_path: str = SchemaField(
|
|
||||||
description="Path to the file in the repository",
|
|
||||||
placeholder="path/to/file",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch to read from",
|
|
||||||
placeholder="branch_name",
|
|
||||||
default="master",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
text_content: str = SchemaField(
|
|
||||||
description="Content of the file (decoded as UTF-8 text)"
|
|
||||||
)
|
|
||||||
raw_content: str = SchemaField(
|
|
||||||
description="Raw base64-encoded content of the file"
|
|
||||||
)
|
|
||||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
|
||||||
description="This block reads the content of a specified file from a GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubReadFileBlock.Input,
|
|
||||||
output_schema=GithubReadFileBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"file_path": "path/to/file",
|
|
||||||
"branch": "master",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("raw_content", "RmlsZSBjb250ZW50"),
|
|
||||||
("text_content", "File content"),
|
|
||||||
("size", 13),
|
|
||||||
],
|
|
||||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def read_file(
|
|
||||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
|
||||||
) -> tuple[str, int]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
content_url = repo_url + f"/contents/{file_path}?ref={branch}"
|
|
||||||
response = await api.get(content_url)
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if isinstance(data, list):
|
|
||||||
# Multiple entries of different types exist at this path
|
|
||||||
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
|
||||||
raise TypeError("Not a file")
|
|
||||||
data = file
|
|
||||||
|
|
||||||
if data["type"] != "file":
|
|
||||||
raise TypeError("Not a file")
|
|
||||||
|
|
||||||
return data["content"], data["size"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
content, size = await self.read_file(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.file_path,
|
|
||||||
input_data.branch,
|
|
||||||
)
|
|
||||||
yield "raw_content", content
|
|
||||||
yield "text_content", base64.b64decode(content).decode("utf-8")
|
|
||||||
yield "size", size
|
|
||||||
|
|
||||||
|
|
||||||
class GithubReadFolderBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
folder_path: str = SchemaField(
|
|
||||||
description="Path to the folder in the repository",
|
|
||||||
placeholder="path/to/folder",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch name to read from (defaults to master)",
|
|
||||||
placeholder="branch_name",
|
|
||||||
default="master",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
class DirEntry(TypedDict):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
|
|
||||||
class FileEntry(TypedDict):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
size: int
|
|
||||||
|
|
||||||
file: FileEntry = SchemaField(description="Files in the folder")
|
|
||||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
|
||||||
error: str = SchemaField(
|
|
||||||
description="Error message if reading the folder failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
|
||||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubReadFolderBlock.Input,
|
|
||||||
output_schema=GithubReadFolderBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"folder_path": "path/to/folder",
|
|
||||||
"branch": "master",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"file",
|
|
||||||
{
|
|
||||||
"name": "file1.txt",
|
|
||||||
"path": "path/to/folder/file1.txt",
|
|
||||||
"size": 1337,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"read_folder": lambda *args, **kwargs: (
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "file1.txt",
|
|
||||||
"path": "path/to/folder/file1.txt",
|
|
||||||
"size": 1337,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def read_folder(
|
|
||||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
|
||||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
|
|
||||||
response = await api.get(contents_url)
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not isinstance(data, list):
|
|
||||||
raise TypeError("Not a folder")
|
|
||||||
|
|
||||||
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
|
||||||
GithubReadFolderBlock.Output.FileEntry(
|
|
||||||
name=entry["name"],
|
|
||||||
path=entry["path"],
|
|
||||||
size=entry["size"],
|
|
||||||
)
|
|
||||||
for entry in data
|
|
||||||
if entry["type"] == "file"
|
|
||||||
]
|
|
||||||
|
|
||||||
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
|
||||||
GithubReadFolderBlock.Output.DirEntry(
|
|
||||||
name=entry["name"],
|
|
||||||
path=entry["path"],
|
|
||||||
)
|
|
||||||
for entry in data
|
|
||||||
if entry["type"] == "dir"
|
|
||||||
]
|
|
||||||
|
|
||||||
return files, dirs
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
files, dirs = await self.read_folder(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.folder_path.lstrip("/"),
|
|
||||||
input_data.branch,
|
|
||||||
)
|
|
||||||
for file in files:
|
|
||||||
yield "file", file
|
|
||||||
for dir in dirs:
|
|
||||||
yield "dir", dir
|
|
||||||
|
|
||||||
|
|
||||||
class GithubMakeBranchBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
new_branch: str = SchemaField(
|
|
||||||
description="Name of the new branch",
|
|
||||||
placeholder="new_branch_name",
|
|
||||||
)
|
|
||||||
source_branch: str = SchemaField(
|
|
||||||
description="Name of the source branch",
|
|
||||||
placeholder="source_branch_name",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
status: str = SchemaField(description="Status of the branch creation operation")
|
|
||||||
error: str = SchemaField(
|
|
||||||
description="Error message if the branch creation failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
|
||||||
description="This block creates a new branch from a specified source branch.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubMakeBranchBlock.Input,
|
|
||||||
output_schema=GithubMakeBranchBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"new_branch": "new_branch_name",
|
|
||||||
"source_branch": "source_branch_name",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[("status", "Branch created successfully")],
|
|
||||||
test_mock={
|
|
||||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create_branch(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
new_branch: str,
|
|
||||||
source_branch: str,
|
|
||||||
) -> str:
|
|
||||||
api = get_api(credentials)
|
|
||||||
ref_url = repo_url + f"/git/refs/heads/{source_branch}"
|
|
||||||
response = await api.get(ref_url)
|
|
||||||
data = response.json()
|
|
||||||
sha = data["object"]["sha"]
|
|
||||||
|
|
||||||
# Create the new branch
|
|
||||||
new_ref_url = repo_url + "/git/refs"
|
|
||||||
data = {
|
|
||||||
"ref": f"refs/heads/{new_branch}",
|
|
||||||
"sha": sha,
|
|
||||||
}
|
|
||||||
response = await api.post(new_ref_url, json=data)
|
|
||||||
return "Branch created successfully"
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
status = await self.create_branch(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.new_branch,
|
|
||||||
input_data.source_branch,
|
|
||||||
)
|
|
||||||
yield "status", status
|
|
||||||
|
|
||||||
|
|
||||||
class GithubDeleteBranchBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Name of the branch to delete",
|
|
||||||
placeholder="branch_name",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
|
||||||
error: str = SchemaField(
|
|
||||||
description="Error message if the branch deletion failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
|
||||||
description="This block deletes a specified branch.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubDeleteBranchBlock.Input,
|
|
||||||
output_schema=GithubDeleteBranchBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"branch": "branch_name",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[("status", "Branch deleted successfully")],
|
|
||||||
test_mock={
|
|
||||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def delete_branch(
|
|
||||||
credentials: GithubCredentials, repo_url: str, branch: str
|
|
||||||
) -> str:
|
|
||||||
api = get_api(credentials)
|
|
||||||
ref_url = repo_url + f"/git/refs/heads/{branch}"
|
|
||||||
await api.delete(ref_url)
|
|
||||||
return "Branch deleted successfully"
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
status = await self.delete_branch(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.branch,
|
|
||||||
)
|
|
||||||
yield "status", status
|
|
||||||
|
|
||||||
|
|
||||||
class GithubCreateFileBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
file_path: str = SchemaField(
|
|
||||||
description="Path where the file should be created",
|
|
||||||
placeholder="path/to/file.txt",
|
|
||||||
)
|
|
||||||
content: str = SchemaField(
|
|
||||||
description="Content to write to the file",
|
|
||||||
placeholder="File content here",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch where the file should be created",
|
|
||||||
default="main",
|
|
||||||
)
|
|
||||||
commit_message: str = SchemaField(
|
|
||||||
description="Message for the commit",
|
|
||||||
default="Create new file",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
url: str = SchemaField(description="URL of the created file")
|
|
||||||
sha: str = SchemaField(description="SHA of the commit")
|
|
||||||
error: str = SchemaField(
|
|
||||||
description="Error message if the file creation failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
|
||||||
description="This block creates a new file in a GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubCreateFileBlock.Input,
|
|
||||||
output_schema=GithubCreateFileBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"file_path": "test/file.txt",
|
|
||||||
"content": "Test content",
|
|
||||||
"branch": "main",
|
|
||||||
"commit_message": "Create test file",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
|
||||||
("sha", "abc123"),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"create_file": lambda *args, **kwargs: (
|
|
||||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
|
||||||
"abc123",
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create_file(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
file_path: str,
|
|
||||||
content: str,
|
|
||||||
branch: str,
|
|
||||||
commit_message: str,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
contents_url = repo_url + f"/contents/{file_path}"
|
|
||||||
content_base64 = base64.b64encode(content.encode()).decode()
|
|
||||||
data = {
|
|
||||||
"message": commit_message,
|
|
||||||
"content": content_base64,
|
|
||||||
"branch": branch,
|
|
||||||
}
|
|
||||||
response = await api.put(contents_url, json=data)
|
|
||||||
data = response.json()
|
|
||||||
return data["content"]["html_url"], data["commit"]["sha"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
url, sha = await self.create_file(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.file_path,
|
|
||||||
input_data.content,
|
|
||||||
input_data.branch,
|
|
||||||
input_data.commit_message,
|
|
||||||
)
|
|
||||||
yield "url", url
|
|
||||||
yield "sha", sha
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubUpdateFileBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
file_path: str = SchemaField(
|
|
||||||
description="Path to the file to update",
|
|
||||||
placeholder="path/to/file.txt",
|
|
||||||
)
|
|
||||||
content: str = SchemaField(
|
|
||||||
description="New content for the file",
|
|
||||||
placeholder="Updated content here",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch containing the file",
|
|
||||||
default="main",
|
|
||||||
)
|
|
||||||
commit_message: str = SchemaField(
|
|
||||||
description="Message for the commit",
|
|
||||||
default="Update file",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
url: str = SchemaField(description="URL of the updated file")
|
|
||||||
sha: str = SchemaField(description="SHA of the commit")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
|
||||||
description="This block updates an existing file in a GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubUpdateFileBlock.Input,
|
|
||||||
output_schema=GithubUpdateFileBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"file_path": "test/file.txt",
|
|
||||||
"content": "Updated content",
|
|
||||||
"branch": "main",
|
|
||||||
"commit_message": "Update test file",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
|
||||||
("sha", "def456"),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"update_file": lambda *args, **kwargs: (
|
|
||||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
|
||||||
"def456",
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_file(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
file_path: str,
|
|
||||||
content: str,
|
|
||||||
branch: str,
|
|
||||||
commit_message: str,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
contents_url = repo_url + f"/contents/{file_path}"
|
|
||||||
params = {"ref": branch}
|
|
||||||
response = await api.get(contents_url, params=params)
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# Convert new content to base64
|
|
||||||
content_base64 = base64.b64encode(content.encode()).decode()
|
|
||||||
data = {
|
|
||||||
"message": commit_message,
|
|
||||||
"content": content_base64,
|
|
||||||
"sha": data["sha"],
|
|
||||||
"branch": branch,
|
|
||||||
}
|
|
||||||
response = await api.put(contents_url, json=data)
|
|
||||||
data = response.json()
|
|
||||||
return data["content"]["html_url"], data["commit"]["sha"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
url, sha = await self.update_file(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.file_path,
|
|
||||||
input_data.content,
|
|
||||||
input_data.branch,
|
|
||||||
input_data.commit_message,
|
|
||||||
)
|
|
||||||
yield "url", url
|
|
||||||
yield "sha", sha
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubCreateRepositoryBlock(Block):
|
class GithubCreateRepositoryBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
@@ -1103,7 +449,7 @@ class GithubListStargazersBlock(Block):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
||||||
description="This block lists all users who have starred a specified GitHub repository.",
|
description="This block lists all users who have starred a specified GitHub repository.",
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
input_schema=GithubListStargazersBlock.Input,
|
input_schema=GithubListStargazersBlock.Input,
|
||||||
@@ -1172,3 +518,230 @@ class GithubListStargazersBlock(Block):
|
|||||||
yield "stargazers", stargazers
|
yield "stargazers", stargazers
|
||||||
for stargazer in stargazers:
|
for stargazer in stargazers:
|
||||||
yield "stargazer", stargazer
|
yield "stargazer", stargazer
|
||||||
|
|
||||||
|
|
||||||
|
class GithubGetRepositoryInfoBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
name: str = SchemaField(description="Repository name")
|
||||||
|
full_name: str = SchemaField(description="Full repository name (owner/repo)")
|
||||||
|
description: str = SchemaField(description="Repository description")
|
||||||
|
default_branch: str = SchemaField(description="Default branch name (e.g. main)")
|
||||||
|
private: bool = SchemaField(description="Whether the repository is private")
|
||||||
|
html_url: str = SchemaField(description="Web URL of the repository")
|
||||||
|
clone_url: str = SchemaField(description="Git clone URL")
|
||||||
|
stars: int = SchemaField(description="Number of stars")
|
||||||
|
forks: int = SchemaField(description="Number of forks")
|
||||||
|
open_issues: int = SchemaField(description="Number of open issues")
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if fetching repo info failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="59d4f241-968a-4040-95da-348ac5c5ce27",
|
||||||
|
description="This block retrieves metadata about a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubGetRepositoryInfoBlock.Input,
|
||||||
|
output_schema=GithubGetRepositoryInfoBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("name", "repo"),
|
||||||
|
("full_name", "owner/repo"),
|
||||||
|
("description", "A test repo"),
|
||||||
|
("default_branch", "main"),
|
||||||
|
("private", False),
|
||||||
|
("html_url", "https://github.com/owner/repo"),
|
||||||
|
("clone_url", "https://github.com/owner/repo.git"),
|
||||||
|
("stars", 42),
|
||||||
|
("forks", 5),
|
||||||
|
("open_issues", 3),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"get_repo_info": lambda *args, **kwargs: {
|
||||||
|
"name": "repo",
|
||||||
|
"full_name": "owner/repo",
|
||||||
|
"description": "A test repo",
|
||||||
|
"default_branch": "main",
|
||||||
|
"private": False,
|
||||||
|
"html_url": "https://github.com/owner/repo",
|
||||||
|
"clone_url": "https://github.com/owner/repo.git",
|
||||||
|
"stargazers_count": 42,
|
||||||
|
"forks_count": 5,
|
||||||
|
"open_issues_count": 3,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
|
||||||
|
api = get_api(credentials)
|
||||||
|
response = await api.get(repo_url)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
data = await self.get_repo_info(credentials, input_data.repo_url)
|
||||||
|
yield "name", data["name"]
|
||||||
|
yield "full_name", data["full_name"]
|
||||||
|
yield "description", data.get("description", "") or ""
|
||||||
|
yield "default_branch", data["default_branch"]
|
||||||
|
yield "private", data["private"]
|
||||||
|
yield "html_url", data["html_url"]
|
||||||
|
yield "clone_url", data["clone_url"]
|
||||||
|
yield "stars", data["stargazers_count"]
|
||||||
|
yield "forks", data["forks_count"]
|
||||||
|
yield "open_issues", data["open_issues_count"]
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubForkRepositoryBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository to fork",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
organization: str = SchemaField(
|
||||||
|
description="Organization to fork into (leave empty to fork to your account)",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
url: str = SchemaField(description="URL of the forked repository")
|
||||||
|
clone_url: str = SchemaField(description="Git clone URL of the fork")
|
||||||
|
full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
|
||||||
|
error: str = SchemaField(description="Error message if the fork failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
|
||||||
|
description="This block forks a GitHub repository to your account or an organization.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubForkRepositoryBlock.Input,
|
||||||
|
output_schema=GithubForkRepositoryBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"organization": "",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("url", "https://github.com/myuser/repo"),
|
||||||
|
("clone_url", "https://github.com/myuser/repo.git"),
|
||||||
|
("full_name", "myuser/repo"),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"fork_repo": lambda *args, **kwargs: (
|
||||||
|
"https://github.com/myuser/repo",
|
||||||
|
"https://github.com/myuser/repo.git",
|
||||||
|
"myuser/repo",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def fork_repo(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
organization: str,
|
||||||
|
) -> tuple[str, str, str]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
forks_url = repo_url + "/forks"
|
||||||
|
data: dict[str, str] = {}
|
||||||
|
if organization:
|
||||||
|
data["organization"] = organization
|
||||||
|
response = await api.post(forks_url, json=data)
|
||||||
|
result = response.json()
|
||||||
|
return result["html_url"], result["clone_url"], result["full_name"]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
url, clone_url, full_name = await self.fork_repo(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.organization,
|
||||||
|
)
|
||||||
|
yield "url", url
|
||||||
|
yield "clone_url", clone_url
|
||||||
|
yield "full_name", full_name
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubStarRepositoryBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository to star",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
status: str = SchemaField(description="Status of the star operation")
|
||||||
|
error: str = SchemaField(description="Error message if starring failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="bd700764-53e3-44dd-a969-d1854088458f",
|
||||||
|
description="This block stars a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubStarRepositoryBlock.Input,
|
||||||
|
output_schema=GithubStarRepositoryBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[("status", "Repository starred successfully")],
|
||||||
|
test_mock={
|
||||||
|
"star_repo": lambda *args, **kwargs: "Repository starred successfully"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
|
||||||
|
api = get_api(credentials, convert_urls=False)
|
||||||
|
repo_path = github_repo_path(repo_url)
|
||||||
|
owner, repo = repo_path.split("/")
|
||||||
|
await api.put(
|
||||||
|
f"https://api.github.com/user/starred/{owner}/{repo}",
|
||||||
|
headers={"Content-Length": "0"},
|
||||||
|
)
|
||||||
|
return "Repository starred successfully"
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
status = await self.star_repo(credentials, input_data.repo_url)
|
||||||
|
yield "status", status
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|||||||
452
autogpt_platform/backend/backend/blocks/github/repo_branches.py
Normal file
452
autogpt_platform/backend/backend/blocks/github/repo_branches.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
from ._api import get_api
|
||||||
|
from ._auth import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
GithubCredentials,
|
||||||
|
GithubCredentialsField,
|
||||||
|
GithubCredentialsInput,
|
||||||
|
)
|
||||||
|
from ._utils import github_repo_path
|
||||||
|
|
||||||
|
|
||||||
|
class GithubListBranchesBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
per_page: int = SchemaField(
|
||||||
|
description="Number of branches to return per page (max 100)",
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
)
|
||||||
|
page: int = SchemaField(
|
||||||
|
description="Page number for pagination",
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class BranchItem(TypedDict):
|
||||||
|
name: str
|
||||||
|
url: str
|
||||||
|
|
||||||
|
branch: BranchItem = SchemaField(
|
||||||
|
title="Branch",
|
||||||
|
description="Branches with their name and file tree browser URL",
|
||||||
|
)
|
||||||
|
branches: list[BranchItem] = SchemaField(
|
||||||
|
description="List of branches with their name and file tree browser URL"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if listing branches failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||||
|
description="This block lists all branches for a specified GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubListBranchesBlock.Input,
|
||||||
|
output_schema=GithubListBranchesBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"per_page": 30,
|
||||||
|
"page": 1,
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"branches",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "main",
|
||||||
|
"url": "https://github.com/owner/repo/tree/main",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"branch",
|
||||||
|
{
|
||||||
|
"name": "main",
|
||||||
|
"url": "https://github.com/owner/repo/tree/main",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"list_branches": lambda *args, **kwargs: [
|
||||||
|
{
|
||||||
|
"name": "main",
|
||||||
|
"url": "https://github.com/owner/repo/tree/main",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def list_branches(
|
||||||
|
credentials: GithubCredentials, repo_url: str, per_page: int, page: int
|
||||||
|
) -> list[Output.BranchItem]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
branches_url = repo_url + "/branches"
|
||||||
|
response = await api.get(
|
||||||
|
branches_url, params={"per_page": str(per_page), "page": str(page)}
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
repo_path = github_repo_path(repo_url)
|
||||||
|
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||||
|
{
|
||||||
|
"name": branch["name"],
|
||||||
|
"url": f"https://github.com/{repo_path}/tree/{branch['name']}",
|
||||||
|
}
|
||||||
|
for branch in data
|
||||||
|
]
|
||||||
|
return branches
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
branches = await self.list_branches(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.per_page,
|
||||||
|
input_data.page,
|
||||||
|
)
|
||||||
|
yield "branches", branches
|
||||||
|
for branch in branches:
|
||||||
|
yield "branch", branch
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubMakeBranchBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
new_branch: str = SchemaField(
|
||||||
|
description="Name of the new branch",
|
||||||
|
placeholder="new_branch_name",
|
||||||
|
)
|
||||||
|
source_branch: str = SchemaField(
|
||||||
|
description="Name of the source branch",
|
||||||
|
placeholder="source_branch_name",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
status: str = SchemaField(description="Status of the branch creation operation")
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the branch creation failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||||
|
description="This block creates a new branch from a specified source branch.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubMakeBranchBlock.Input,
|
||||||
|
output_schema=GithubMakeBranchBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"new_branch": "new_branch_name",
|
||||||
|
"source_branch": "source_branch_name",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[("status", "Branch created successfully")],
|
||||||
|
test_mock={
|
||||||
|
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_branch(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
new_branch: str,
|
||||||
|
source_branch: str,
|
||||||
|
) -> str:
|
||||||
|
api = get_api(credentials)
|
||||||
|
ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
|
||||||
|
response = await api.get(ref_url)
|
||||||
|
data = response.json()
|
||||||
|
sha = data["object"]["sha"]
|
||||||
|
|
||||||
|
# Create the new branch
|
||||||
|
new_ref_url = repo_url + "/git/refs"
|
||||||
|
data = {
|
||||||
|
"ref": f"refs/heads/{new_branch}",
|
||||||
|
"sha": sha,
|
||||||
|
}
|
||||||
|
response = await api.post(new_ref_url, json=data)
|
||||||
|
return "Branch created successfully"
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
status = await self.create_branch(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.new_branch,
|
||||||
|
input_data.source_branch,
|
||||||
|
)
|
||||||
|
yield "status", status
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubDeleteBranchBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Name of the branch to delete",
|
||||||
|
placeholder="branch_name",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the branch deletion failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||||
|
description="This block deletes a specified branch.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubDeleteBranchBlock.Input,
|
||||||
|
output_schema=GithubDeleteBranchBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"branch": "branch_name",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[("status", "Branch deleted successfully")],
|
||||||
|
test_mock={
|
||||||
|
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||||
|
},
|
||||||
|
is_sensitive_action=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_branch(
|
||||||
|
credentials: GithubCredentials, repo_url: str, branch: str
|
||||||
|
) -> str:
|
||||||
|
api = get_api(credentials)
|
||||||
|
ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
|
||||||
|
await api.delete(ref_url)
|
||||||
|
return "Branch deleted successfully"
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
status = await self.delete_branch(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.branch,
|
||||||
|
)
|
||||||
|
yield "status", status
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubCompareBranchesBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
base: str = SchemaField(
|
||||||
|
description="Base branch or commit SHA",
|
||||||
|
placeholder="main",
|
||||||
|
)
|
||||||
|
head: str = SchemaField(
|
||||||
|
description="Head branch or commit SHA to compare against base",
|
||||||
|
placeholder="feature-branch",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class FileChange(TypedDict):
|
||||||
|
filename: str
|
||||||
|
status: str
|
||||||
|
additions: int
|
||||||
|
deletions: int
|
||||||
|
patch: str
|
||||||
|
|
||||||
|
status: str = SchemaField(
|
||||||
|
description="Comparison status: ahead, behind, diverged, or identical"
|
||||||
|
)
|
||||||
|
ahead_by: int = SchemaField(
|
||||||
|
description="Number of commits head is ahead of base"
|
||||||
|
)
|
||||||
|
behind_by: int = SchemaField(
|
||||||
|
description="Number of commits head is behind base"
|
||||||
|
)
|
||||||
|
total_commits: int = SchemaField(
|
||||||
|
description="Total number of commits in the comparison"
|
||||||
|
)
|
||||||
|
diff: str = SchemaField(description="Unified diff of all file changes")
|
||||||
|
file: FileChange = SchemaField(
|
||||||
|
title="Changed File", description="A changed file with its diff"
|
||||||
|
)
|
||||||
|
files: list[FileChange] = SchemaField(
|
||||||
|
description="List of changed files with their diffs"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if comparison failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="2e4faa8c-6086-4546-ba77-172d1d560186",
|
||||||
|
description="This block compares two branches or commits in a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubCompareBranchesBlock.Input,
|
||||||
|
output_schema=GithubCompareBranchesBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"base": "main",
|
||||||
|
"head": "feature",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("status", "ahead"),
|
||||||
|
("ahead_by", 2),
|
||||||
|
("behind_by", 0),
|
||||||
|
("total_commits", 2),
|
||||||
|
("diff", "+++ b/file.py\n+new line"),
|
||||||
|
(
|
||||||
|
"files",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"filename": "file.py",
|
||||||
|
"status": "modified",
|
||||||
|
"additions": 1,
|
||||||
|
"deletions": 0,
|
||||||
|
"patch": "+new line",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"file",
|
||||||
|
{
|
||||||
|
"filename": "file.py",
|
||||||
|
"status": "modified",
|
||||||
|
"additions": 1,
|
||||||
|
"deletions": 0,
|
||||||
|
"patch": "+new line",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"compare_branches": lambda *args, **kwargs: {
|
||||||
|
"status": "ahead",
|
||||||
|
"ahead_by": 2,
|
||||||
|
"behind_by": 0,
|
||||||
|
"total_commits": 2,
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"filename": "file.py",
|
||||||
|
"status": "modified",
|
||||||
|
"additions": 1,
|
||||||
|
"deletions": 0,
|
||||||
|
"patch": "+new line",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def compare_branches(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
base: str,
|
||||||
|
head: str,
|
||||||
|
) -> dict:
|
||||||
|
api = get_api(credentials)
|
||||||
|
safe_base = quote(base, safe="")
|
||||||
|
safe_head = quote(head, safe="")
|
||||||
|
compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
|
||||||
|
response = await api.get(compare_url)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
data = await self.compare_branches(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.base,
|
||||||
|
input_data.head,
|
||||||
|
)
|
||||||
|
yield "status", data["status"]
|
||||||
|
yield "ahead_by", data["ahead_by"]
|
||||||
|
yield "behind_by", data["behind_by"]
|
||||||
|
yield "total_commits", data["total_commits"]
|
||||||
|
|
||||||
|
files: list[GithubCompareBranchesBlock.Output.FileChange] = [
|
||||||
|
GithubCompareBranchesBlock.Output.FileChange(
|
||||||
|
filename=f["filename"],
|
||||||
|
status=f["status"],
|
||||||
|
additions=f["additions"],
|
||||||
|
deletions=f["deletions"],
|
||||||
|
patch=f.get("patch", ""),
|
||||||
|
)
|
||||||
|
for f in data.get("files", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build unified diff
|
||||||
|
diff_parts = []
|
||||||
|
for f in data.get("files", []):
|
||||||
|
patch = f.get("patch", "")
|
||||||
|
if patch:
|
||||||
|
diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
|
||||||
|
yield "diff", "\n".join(diff_parts)
|
||||||
|
|
||||||
|
yield "files", files
|
||||||
|
for file in files:
|
||||||
|
yield "file", file
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
720
autogpt_platform/backend/backend/blocks/github/repo_files.py
Normal file
720
autogpt_platform/backend/backend/blocks/github/repo_files.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
import base64
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from backend.blocks._base import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
from ._api import get_api
|
||||||
|
from ._auth import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
GithubCredentials,
|
||||||
|
GithubCredentialsField,
|
||||||
|
GithubCredentialsInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubReadFileBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
file_path: str = SchemaField(
|
||||||
|
description="Path to the file in the repository",
|
||||||
|
placeholder="path/to/file",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch to read from",
|
||||||
|
placeholder="branch_name",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
text_content: str = SchemaField(
|
||||||
|
description="Content of the file (decoded as UTF-8 text)"
|
||||||
|
)
|
||||||
|
raw_content: str = SchemaField(
|
||||||
|
description="Raw base64-encoded content of the file"
|
||||||
|
)
|
||||||
|
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||||
|
error: str = SchemaField(description="Error message if reading the file failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||||
|
description="This block reads the content of a specified file from a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubReadFileBlock.Input,
|
||||||
|
output_schema=GithubReadFileBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"file_path": "path/to/file",
|
||||||
|
"branch": "main",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("raw_content", "RmlsZSBjb250ZW50"),
|
||||||
|
("text_content", "File content"),
|
||||||
|
("size", 13),
|
||||||
|
],
|
||||||
|
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def read_file(
|
||||||
|
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||||
|
) -> tuple[str, int]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
content_url = (
|
||||||
|
repo_url
|
||||||
|
+ f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
|
||||||
|
)
|
||||||
|
response = await api.get(content_url)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if isinstance(data, list):
|
||||||
|
# Multiple entries of different types exist at this path
|
||||||
|
if not (file := next((f for f in data if f["type"] == "file"), None)):
|
||||||
|
raise TypeError("Not a file")
|
||||||
|
data = file
|
||||||
|
|
||||||
|
if data["type"] != "file":
|
||||||
|
raise TypeError("Not a file")
|
||||||
|
|
||||||
|
return data["content"], data["size"]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
content, size = await self.read_file(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.file_path,
|
||||||
|
input_data.branch,
|
||||||
|
)
|
||||||
|
yield "raw_content", content
|
||||||
|
yield "text_content", base64.b64decode(content).decode("utf-8")
|
||||||
|
yield "size", size
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubReadFolderBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
folder_path: str = SchemaField(
|
||||||
|
description="Path to the folder in the repository",
|
||||||
|
placeholder="path/to/folder",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch name to read from (defaults to main)",
|
||||||
|
placeholder="branch_name",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class DirEntry(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
|
||||||
|
class FileEntry(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
size: int
|
||||||
|
|
||||||
|
file: FileEntry = SchemaField(description="Files in the folder")
|
||||||
|
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if reading the folder failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||||
|
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubReadFolderBlock.Input,
|
||||||
|
output_schema=GithubReadFolderBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"folder_path": "path/to/folder",
|
||||||
|
"branch": "main",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"file",
|
||||||
|
{
|
||||||
|
"name": "file1.txt",
|
||||||
|
"path": "path/to/folder/file1.txt",
|
||||||
|
"size": 1337,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"read_folder": lambda *args, **kwargs: (
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "file1.txt",
|
||||||
|
"path": "path/to/folder/file1.txt",
|
||||||
|
"size": 1337,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def read_folder(
|
||||||
|
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||||
|
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
contents_url = (
|
||||||
|
repo_url
|
||||||
|
+ f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
|
||||||
|
)
|
||||||
|
response = await api.get(contents_url)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise TypeError("Not a folder")
|
||||||
|
|
||||||
|
files: list[GithubReadFolderBlock.Output.FileEntry] = [
|
||||||
|
GithubReadFolderBlock.Output.FileEntry(
|
||||||
|
name=entry["name"],
|
||||||
|
path=entry["path"],
|
||||||
|
size=entry["size"],
|
||||||
|
)
|
||||||
|
for entry in data
|
||||||
|
if entry["type"] == "file"
|
||||||
|
]
|
||||||
|
|
||||||
|
dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
|
||||||
|
GithubReadFolderBlock.Output.DirEntry(
|
||||||
|
name=entry["name"],
|
||||||
|
path=entry["path"],
|
||||||
|
)
|
||||||
|
for entry in data
|
||||||
|
if entry["type"] == "dir"
|
||||||
|
]
|
||||||
|
|
||||||
|
return files, dirs
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
files, dirs = await self.read_folder(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.folder_path.lstrip("/"),
|
||||||
|
input_data.branch,
|
||||||
|
)
|
||||||
|
for file in files:
|
||||||
|
yield "file", file
|
||||||
|
for dir in dirs:
|
||||||
|
yield "dir", dir
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubCreateFileBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
file_path: str = SchemaField(
|
||||||
|
description="Path where the file should be created",
|
||||||
|
placeholder="path/to/file.txt",
|
||||||
|
)
|
||||||
|
content: str = SchemaField(
|
||||||
|
description="Content to write to the file",
|
||||||
|
placeholder="File content here",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch where the file should be created",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
commit_message: str = SchemaField(
|
||||||
|
description="Message for the commit",
|
||||||
|
default="Create new file",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
url: str = SchemaField(description="URL of the created file")
|
||||||
|
sha: str = SchemaField(description="SHA of the commit")
|
||||||
|
error: str = SchemaField(
|
||||||
|
description="Error message if the file creation failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8fd132ac-b917-428a-8159-d62893e8a3fe",
|
||||||
|
description="This block creates a new file in a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubCreateFileBlock.Input,
|
||||||
|
output_schema=GithubCreateFileBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"file_path": "test/file.txt",
|
||||||
|
"content": "Test content",
|
||||||
|
"branch": "main",
|
||||||
|
"commit_message": "Create test file",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||||
|
("sha", "abc123"),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"create_file": lambda *args, **kwargs: (
|
||||||
|
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||||
|
"abc123",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_file(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
file_path: str,
|
||||||
|
content: str,
|
||||||
|
branch: str,
|
||||||
|
commit_message: str,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||||
|
content_base64 = base64.b64encode(content.encode()).decode()
|
||||||
|
data = {
|
||||||
|
"message": commit_message,
|
||||||
|
"content": content_base64,
|
||||||
|
"branch": branch,
|
||||||
|
}
|
||||||
|
response = await api.put(contents_url, json=data)
|
||||||
|
data = response.json()
|
||||||
|
return data["content"]["html_url"], data["commit"]["sha"]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
url, sha = await self.create_file(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.file_path,
|
||||||
|
input_data.content,
|
||||||
|
input_data.branch,
|
||||||
|
input_data.commit_message,
|
||||||
|
)
|
||||||
|
yield "url", url
|
||||||
|
yield "sha", sha
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubUpdateFileBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
file_path: str = SchemaField(
|
||||||
|
description="Path to the file to update",
|
||||||
|
placeholder="path/to/file.txt",
|
||||||
|
)
|
||||||
|
content: str = SchemaField(
|
||||||
|
description="New content for the file",
|
||||||
|
placeholder="Updated content here",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch containing the file",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
commit_message: str = SchemaField(
|
||||||
|
description="Message for the commit",
|
||||||
|
default="Update file",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
url: str = SchemaField(description="URL of the updated file")
|
||||||
|
sha: str = SchemaField(description="SHA of the commit")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
||||||
|
description="This block updates an existing file in a GitHub repository.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubUpdateFileBlock.Input,
|
||||||
|
output_schema=GithubUpdateFileBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"file_path": "test/file.txt",
|
||||||
|
"content": "Updated content",
|
||||||
|
"branch": "main",
|
||||||
|
"commit_message": "Update test file",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
||||||
|
("sha", "def456"),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"update_file": lambda *args, **kwargs: (
|
||||||
|
"https://github.com/owner/repo/blob/main/test/file.txt",
|
||||||
|
"def456",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_file(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
file_path: str,
|
||||||
|
content: str,
|
||||||
|
branch: str,
|
||||||
|
commit_message: str,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
||||||
|
params = {"ref": branch}
|
||||||
|
response = await api.get(contents_url, params=params)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Convert new content to base64
|
||||||
|
content_base64 = base64.b64encode(content.encode()).decode()
|
||||||
|
data = {
|
||||||
|
"message": commit_message,
|
||||||
|
"content": content_base64,
|
||||||
|
"sha": data["sha"],
|
||||||
|
"branch": branch,
|
||||||
|
}
|
||||||
|
response = await api.put(contents_url, json=data)
|
||||||
|
data = response.json()
|
||||||
|
return data["content"]["html_url"], data["commit"]["sha"]
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
url, sha = await self.update_file(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.file_path,
|
||||||
|
input_data.content,
|
||||||
|
input_data.branch,
|
||||||
|
input_data.commit_message,
|
||||||
|
)
|
||||||
|
yield "url", url
|
||||||
|
yield "sha", sha
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubSearchCodeBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
query: str = SchemaField(
|
||||||
|
description="Search query (GitHub code search syntax)",
|
||||||
|
placeholder="className language:python",
|
||||||
|
)
|
||||||
|
repo: str = SchemaField(
|
||||||
|
description="Restrict search to a repository (owner/repo format, optional)",
|
||||||
|
default="",
|
||||||
|
placeholder="owner/repo",
|
||||||
|
)
|
||||||
|
per_page: int = SchemaField(
|
||||||
|
description="Number of results to return (max 100)",
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class SearchResult(TypedDict):
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
repository: str
|
||||||
|
url: str
|
||||||
|
score: float
|
||||||
|
|
||||||
|
result: SearchResult = SchemaField(
|
||||||
|
title="Result", description="A code search result"
|
||||||
|
)
|
||||||
|
results: list[SearchResult] = SchemaField(
|
||||||
|
description="List of code search results"
|
||||||
|
)
|
||||||
|
total_count: int = SchemaField(description="Total number of matching results")
|
||||||
|
error: str = SchemaField(description="Error message if search failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
||||||
|
description="This block searches for code in GitHub repositories.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubSearchCodeBlock.Input,
|
||||||
|
output_schema=GithubSearchCodeBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"query": "addClass",
|
||||||
|
"repo": "owner/repo",
|
||||||
|
"per_page": 30,
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("total_count", 1),
|
||||||
|
(
|
||||||
|
"results",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "file.py",
|
||||||
|
"path": "src/file.py",
|
||||||
|
"repository": "owner/repo",
|
||||||
|
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||||
|
"score": 1.0,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"result",
|
||||||
|
{
|
||||||
|
"name": "file.py",
|
||||||
|
"path": "src/file.py",
|
||||||
|
"repository": "owner/repo",
|
||||||
|
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||||
|
"score": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"search_code": lambda *args, **kwargs: (
|
||||||
|
1,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "file.py",
|
||||||
|
"path": "src/file.py",
|
||||||
|
"repository": "owner/repo",
|
||||||
|
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
||||||
|
"score": 1.0,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def search_code(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
query: str,
|
||||||
|
repo: str,
|
||||||
|
per_page: int,
|
||||||
|
) -> tuple[int, list[Output.SearchResult]]:
|
||||||
|
api = get_api(credentials, convert_urls=False)
|
||||||
|
full_query = f"{query} repo:{repo}" if repo else query
|
||||||
|
params = {"q": full_query, "per_page": str(per_page)}
|
||||||
|
response = await api.get("https://api.github.com/search/code", params=params)
|
||||||
|
data = response.json()
|
||||||
|
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
||||||
|
GithubSearchCodeBlock.Output.SearchResult(
|
||||||
|
name=item["name"],
|
||||||
|
path=item["path"],
|
||||||
|
repository=item["repository"]["full_name"],
|
||||||
|
url=item["html_url"],
|
||||||
|
score=item["score"],
|
||||||
|
)
|
||||||
|
for item in data["items"]
|
||||||
|
]
|
||||||
|
return data["total_count"], results
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
total_count, results = await self.search_code(
|
||||||
|
credentials,
|
||||||
|
input_data.query,
|
||||||
|
input_data.repo,
|
||||||
|
input_data.per_page,
|
||||||
|
)
|
||||||
|
yield "total_count", total_count
|
||||||
|
yield "results", results
|
||||||
|
for result in results:
|
||||||
|
yield "result", result
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|
||||||
|
class GithubGetRepositoryTreeBlock(Block):
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||||
|
repo_url: str = SchemaField(
|
||||||
|
description="URL of the GitHub repository",
|
||||||
|
placeholder="https://github.com/owner/repo",
|
||||||
|
)
|
||||||
|
branch: str = SchemaField(
|
||||||
|
description="Branch name to get the tree from",
|
||||||
|
default="main",
|
||||||
|
)
|
||||||
|
recursive: bool = SchemaField(
|
||||||
|
description="Whether to recursively list the entire tree",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
class TreeEntry(TypedDict):
|
||||||
|
path: str
|
||||||
|
type: str
|
||||||
|
size: int
|
||||||
|
sha: str
|
||||||
|
|
||||||
|
entry: TreeEntry = SchemaField(
|
||||||
|
title="Tree Entry", description="A file or directory in the tree"
|
||||||
|
)
|
||||||
|
entries: list[TreeEntry] = SchemaField(
|
||||||
|
description="List of all files and directories in the tree"
|
||||||
|
)
|
||||||
|
truncated: bool = SchemaField(
|
||||||
|
description="Whether the tree was truncated due to size"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if getting tree failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
||||||
|
description="This block lists the entire file tree of a GitHub repository recursively.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=GithubGetRepositoryTreeBlock.Input,
|
||||||
|
output_schema=GithubGetRepositoryTreeBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"branch": "main",
|
||||||
|
"recursive": True,
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[
|
||||||
|
("truncated", False),
|
||||||
|
(
|
||||||
|
"entries",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "src/main.py",
|
||||||
|
"type": "blob",
|
||||||
|
"size": 1234,
|
||||||
|
"sha": "abc123",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"entry",
|
||||||
|
{
|
||||||
|
"path": "src/main.py",
|
||||||
|
"type": "blob",
|
||||||
|
"size": 1234,
|
||||||
|
"sha": "abc123",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"get_tree": lambda *args, **kwargs: (
|
||||||
|
False,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "src/main.py",
|
||||||
|
"type": "blob",
|
||||||
|
"size": 1234,
|
||||||
|
"sha": "abc123",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_tree(
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
repo_url: str,
|
||||||
|
branch: str,
|
||||||
|
recursive: bool,
|
||||||
|
) -> tuple[bool, list[Output.TreeEntry]]:
|
||||||
|
api = get_api(credentials)
|
||||||
|
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
||||||
|
params = {"recursive": "1"} if recursive else {}
|
||||||
|
response = await api.get(tree_url, params=params)
|
||||||
|
data = response.json()
|
||||||
|
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
||||||
|
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
||||||
|
path=item["path"],
|
||||||
|
type=item["type"],
|
||||||
|
size=item.get("size", 0),
|
||||||
|
sha=item["sha"],
|
||||||
|
)
|
||||||
|
for item in data["tree"]
|
||||||
|
]
|
||||||
|
return data.get("truncated", False), entries
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: GithubCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
truncated, entries = await self.get_tree(
|
||||||
|
credentials,
|
||||||
|
input_data.repo_url,
|
||||||
|
input_data.branch,
|
||||||
|
input_data.recursive,
|
||||||
|
)
|
||||||
|
yield "truncated", truncated
|
||||||
|
yield "entries", entries
|
||||||
|
for entry in entries:
|
||||||
|
yield "entry", entry
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||||
|
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
|
||||||
|
from backend.blocks.github.pull_requests import (
|
||||||
|
GithubMergePullRequestBlock,
|
||||||
|
prepare_pr_api_url,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
|
||||||
|
# ── prepare_pr_api_url tests ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreparePrApiUrl:
|
||||||
|
def test_https_scheme_preserved(self):
|
||||||
|
result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
|
||||||
|
assert result == "https://github.com/owner/repo/pulls/42/merge"
|
||||||
|
|
||||||
|
def test_http_scheme_preserved(self):
|
||||||
|
result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
|
||||||
|
assert result == "http://github.com/owner/repo/pulls/1/files"
|
||||||
|
|
||||||
|
def test_no_scheme_defaults_to_https(self):
|
||||||
|
result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
|
||||||
|
assert result == "https://github.com/owner/repo/pulls/5/merge"
|
||||||
|
|
||||||
|
def test_reviewers_path(self):
|
||||||
|
result = prepare_pr_api_url(
|
||||||
|
"https://github.com/owner/repo/pull/99", "requested_reviewers"
|
||||||
|
)
|
||||||
|
assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"
|
||||||
|
|
||||||
|
def test_invalid_url_returned_as_is(self):
|
||||||
|
url = "https://example.com/not-a-pr"
|
||||||
|
assert prepare_pr_api_url(url, "merge") == url
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert prepare_pr_api_url("", "merge") == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Error-path block tests ──
|
||||||
|
# When a block's run() yields ("error", msg), _execute() converts it to a
|
||||||
|
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
|
||||||
|
# which returns early on empty test_output).
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_block(block, mocks: dict):
|
||||||
|
"""Apply mocks to a block's static methods, wrapping sync mocks as async."""
|
||||||
|
for name, mock_fn in mocks.items():
|
||||||
|
original = getattr(block, name)
|
||||||
|
if inspect.iscoroutinefunction(original):
|
||||||
|
|
||||||
|
async def async_mock(*args, _fn=mock_fn, **kwargs):
|
||||||
|
return _fn(*args, **kwargs)
|
||||||
|
|
||||||
|
setattr(block, name, async_mock)
|
||||||
|
else:
|
||||||
|
setattr(block, name, mock_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def _raise(exc: Exception):
|
||||||
|
"""Helper that returns a callable which raises the given exception."""
|
||||||
|
|
||||||
|
def _raiser(*args, **kwargs):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
return _raiser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_merge_pr_error_path():
|
||||||
|
block = GithubMergePullRequestBlock()
|
||||||
|
_mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
|
||||||
|
input_data = {
|
||||||
|
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||||
|
"merge_method": "squash",
|
||||||
|
"commit_title": "",
|
||||||
|
"commit_message": "",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
}
|
||||||
|
with pytest.raises(BlockExecutionError, match="PR not mergeable"):
|
||||||
|
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_file_commit_error_path():
|
||||||
|
block = GithubMultiFileCommitBlock()
|
||||||
|
_mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
|
||||||
|
input_data = {
|
||||||
|
"repo_url": "https://github.com/owner/repo",
|
||||||
|
"branch": "feature",
|
||||||
|
"commit_message": "test",
|
||||||
|
"files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
}
|
||||||
|
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||||
|
async for _ in block.execute(
|
||||||
|
input_data,
|
||||||
|
credentials=TEST_CREDENTIALS,
|
||||||
|
execution_context=ExecutionContext(),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── FileOperation enum tests ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileOperation:
|
||||||
|
def test_upsert_value(self):
|
||||||
|
assert FileOperation.UPSERT == "upsert"
|
||||||
|
|
||||||
|
def test_delete_value(self):
|
||||||
|
assert FileOperation.DELETE == "delete"
|
||||||
|
|
||||||
|
def test_invalid_value_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FileOperation("create")
|
||||||
|
|
||||||
|
def test_invalid_value_raises_typo(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FileOperation("upser")
|
||||||
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
|
|||||||
h.ignore_links = False
|
h.ignore_links = False
|
||||||
h.ignore_images = True
|
h.ignore_images = True
|
||||||
return h.handle(html_content)
|
return h.handle(html_content)
|
||||||
except ImportError:
|
except Exception:
|
||||||
# Fallback: return raw HTML if html2text is not available
|
# Keep extraction resilient if html2text is unavailable or fails.
|
||||||
return html_content
|
return html_content
|
||||||
|
|
||||||
# Handle content stored as attachment
|
# Handle content stored as attachment
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class HITLReviewHelper:
|
|||||||
graph_version: int,
|
graph_version: int,
|
||||||
block_name: str = "Block",
|
block_name: str = "Block",
|
||||||
editable: bool = False,
|
editable: bool = False,
|
||||||
|
is_graph_execution: bool = True,
|
||||||
) -> Optional[ReviewResult]:
|
) -> Optional[ReviewResult]:
|
||||||
"""
|
"""
|
||||||
Handle a review request for a block that requires human review.
|
Handle a review request for a block that requires human review.
|
||||||
@@ -143,10 +144,11 @@ class HITLReviewHelper:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
|
||||||
)
|
)
|
||||||
await HITLReviewHelper.update_node_execution_status(
|
if is_graph_execution:
|
||||||
exec_id=node_exec_id,
|
await HITLReviewHelper.update_node_execution_status(
|
||||||
status=ExecutionStatus.REVIEW,
|
exec_id=node_exec_id,
|
||||||
)
|
status=ExecutionStatus.REVIEW,
|
||||||
|
)
|
||||||
return None # Signal that execution should pause
|
return None # Signal that execution should pause
|
||||||
|
|
||||||
# Mark review as processed if not already done
|
# Mark review as processed if not already done
|
||||||
@@ -168,6 +170,7 @@ class HITLReviewHelper:
|
|||||||
graph_version: int,
|
graph_version: int,
|
||||||
block_name: str = "Block",
|
block_name: str = "Block",
|
||||||
editable: bool = False,
|
editable: bool = False,
|
||||||
|
is_graph_execution: bool = True,
|
||||||
) -> Optional[ReviewDecision]:
|
) -> Optional[ReviewDecision]:
|
||||||
"""
|
"""
|
||||||
Handle a review request and return the decision in a single call.
|
Handle a review request and return the decision in a single call.
|
||||||
@@ -197,6 +200,7 @@ class HITLReviewHelper:
|
|||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
block_name=block_name,
|
block_name=block_name,
|
||||||
editable=editable,
|
editable=editable,
|
||||||
|
is_graph_execution=is_graph_execution,
|
||||||
)
|
)
|
||||||
|
|
||||||
if review_result is None:
|
if review_result is None:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
|
|||||||
from backend.blocks.search import GetRequest
|
from backend.blocks.search import GetRequest
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
|
||||||
|
|
||||||
|
|
||||||
class SearchTheWebBlock(Block, GetRequest):
|
class SearchTheWebBlock(Block, GetRequest):
|
||||||
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if input_data.raw_content:
|
if input_data.raw_content:
|
||||||
try:
|
try:
|
||||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
parsed_url, _, _ = await validate_url_host(input_data.url)
|
||||||
url = parsed_url.geturl()
|
url = parsed_url.geturl()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
yield "error", f"Invalid URL: {e}"
|
yield "error", f"Invalid URL: {e}"
|
||||||
|
|||||||
@@ -31,10 +31,14 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
from backend.util.clients import OPENROUTER_BASE_URL
|
||||||
from backend.util.logging import TruncatedLogger
|
from backend.util.logging import TruncatedLogger
|
||||||
from backend.util.prompt import compress_context, estimate_token_count
|
from backend.util.prompt import compress_context, estimate_token_count
|
||||||
|
from backend.util.request import validate_url_host
|
||||||
|
from backend.util.settings import Settings
|
||||||
from backend.util.text import TextFormatter
|
from backend.util.text import TextFormatter
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||||
fmt = TextFormatter(autoescape=False)
|
fmt = TextFormatter(autoescape=False)
|
||||||
|
|
||||||
@@ -136,19 +140,31 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
# OpenRouter models
|
# OpenRouter models
|
||||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
||||||
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
||||||
|
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
||||||
|
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
||||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||||
|
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
||||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||||
|
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
||||||
|
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
||||||
|
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
||||||
|
CODESTRAL = "mistralai/codestral-2508"
|
||||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||||
|
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
||||||
|
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
||||||
|
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
||||||
|
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
||||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||||
|
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
||||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||||
@@ -156,9 +172,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||||
|
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
||||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||||
|
GROK_3 = "x-ai/grok-3"
|
||||||
GROK_4 = "x-ai/grok-4"
|
GROK_4 = "x-ai/grok-4"
|
||||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||||
@@ -336,17 +354,41 @@ MODEL_METADATA = {
|
|||||||
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
|
||||||
),
|
),
|
||||||
# https://openrouter.ai/models
|
# https://openrouter.ai/models
|
||||||
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
|
||||||
"open_router",
|
"open_router",
|
||||||
1050000,
|
1048576,
|
||||||
8192,
|
65536,
|
||||||
"Gemini 2.5 Pro Preview 03.25",
|
"Gemini 2.5 Pro Preview 03.25",
|
||||||
"OpenRouter",
|
"OpenRouter",
|
||||||
"Google",
|
"Google",
|
||||||
2,
|
2,
|
||||||
),
|
),
|
||||||
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
|
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
|
||||||
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
|
"open_router",
|
||||||
|
1048576,
|
||||||
|
65536,
|
||||||
|
"Gemini 2.5 Pro",
|
||||||
|
"OpenRouter",
|
||||||
|
"Google",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
1048576,
|
||||||
|
65536,
|
||||||
|
"Gemini 3.1 Pro Preview",
|
||||||
|
"OpenRouter",
|
||||||
|
"Google",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
1048576,
|
||||||
|
65536,
|
||||||
|
"Gemini 3 Flash Preview",
|
||||||
|
"OpenRouter",
|
||||||
|
"Google",
|
||||||
|
1,
|
||||||
),
|
),
|
||||||
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
|
||||||
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
|
||||||
@@ -354,6 +396,15 @@ MODEL_METADATA = {
|
|||||||
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
|
||||||
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
|
||||||
),
|
),
|
||||||
|
LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
1048576,
|
||||||
|
65536,
|
||||||
|
"Gemini 3.1 Flash Lite Preview",
|
||||||
|
"OpenRouter",
|
||||||
|
"Google",
|
||||||
|
1,
|
||||||
|
),
|
||||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
|
||||||
"open_router",
|
"open_router",
|
||||||
1048576,
|
1048576,
|
||||||
@@ -375,12 +426,78 @@ MODEL_METADATA = {
|
|||||||
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||||
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
|
||||||
),
|
),
|
||||||
|
LlmModel.MISTRAL_LARGE_3: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
262144,
|
||||||
|
None,
|
||||||
|
"Mistral Large 3 2512",
|
||||||
|
"OpenRouter",
|
||||||
|
"Mistral AI",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
131072,
|
||||||
|
None,
|
||||||
|
"Mistral Medium 3.1",
|
||||||
|
"OpenRouter",
|
||||||
|
"Mistral AI",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
131072,
|
||||||
|
131072,
|
||||||
|
"Mistral Small 3.2 24B",
|
||||||
|
"OpenRouter",
|
||||||
|
"Mistral AI",
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
LlmModel.CODESTRAL: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
256000,
|
||||||
|
None,
|
||||||
|
"Codestral 2508",
|
||||||
|
"OpenRouter",
|
||||||
|
"Mistral AI",
|
||||||
|
1,
|
||||||
|
),
|
||||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||||
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
|
||||||
),
|
),
|
||||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||||
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
|
||||||
),
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
|
||||||
|
"open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
|
||||||
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
128000,
|
||||||
|
8192,
|
||||||
|
"Command A Translate 08.2025",
|
||||||
|
"OpenRouter",
|
||||||
|
"Cohere",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
256000,
|
||||||
|
32768,
|
||||||
|
"Command A Reasoning 08.2025",
|
||||||
|
"OpenRouter",
|
||||||
|
"Cohere",
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
128000,
|
||||||
|
8192,
|
||||||
|
"Command A Vision 07.2025",
|
||||||
|
"OpenRouter",
|
||||||
|
"Cohere",
|
||||||
|
2,
|
||||||
|
),
|
||||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||||
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
|
||||||
),
|
),
|
||||||
@@ -393,6 +510,15 @@ MODEL_METADATA = {
|
|||||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
|
||||||
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
|
||||||
),
|
),
|
||||||
|
LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
128000,
|
||||||
|
8000,
|
||||||
|
"Sonar Reasoning Pro",
|
||||||
|
"OpenRouter",
|
||||||
|
"Perplexity",
|
||||||
|
2,
|
||||||
|
),
|
||||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||||
"open_router",
|
"open_router",
|
||||||
128000,
|
128000,
|
||||||
@@ -438,6 +564,9 @@ MODEL_METADATA = {
|
|||||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||||
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
|
||||||
),
|
),
|
||||||
|
LlmModel.MICROSOFT_PHI_4: ModelMetadata(
|
||||||
|
"open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
|
||||||
|
),
|
||||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||||
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
|
||||||
),
|
),
|
||||||
@@ -447,6 +576,15 @@ MODEL_METADATA = {
|
|||||||
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
|
||||||
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
|
||||||
),
|
),
|
||||||
|
LlmModel.GROK_3: ModelMetadata(
|
||||||
|
"open_router",
|
||||||
|
131072,
|
||||||
|
131072,
|
||||||
|
"Grok 3",
|
||||||
|
"OpenRouter",
|
||||||
|
"xAI",
|
||||||
|
2,
|
||||||
|
),
|
||||||
LlmModel.GROK_4: ModelMetadata(
|
LlmModel.GROK_4: ModelMetadata(
|
||||||
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
|
||||||
),
|
),
|
||||||
@@ -804,6 +942,11 @@ async def llm_call(
|
|||||||
if tools:
|
if tools:
|
||||||
raise ValueError("Ollama does not support tools.")
|
raise ValueError("Ollama does not support tools.")
|
||||||
|
|
||||||
|
# Validate user-provided Ollama host to prevent SSRF etc.
|
||||||
|
await validate_url_host(
|
||||||
|
ollama_host, trusted_hostnames=[settings.config.ollama_host]
|
||||||
|
)
|
||||||
|
|
||||||
client = ollama.AsyncClient(host=ollama_host)
|
client = ollama.AsyncClient(host=ollama_host)
|
||||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||||
@@ -825,7 +968,7 @@ async def llm_call(
|
|||||||
elif provider == "open_router":
|
elif provider == "open_router":
|
||||||
tools_param = tools if tools else openai.NOT_GIVEN
|
tools_param = tools if tools else openai.NOT_GIVEN
|
||||||
client = openai.AsyncOpenAI(
|
client = openai.AsyncOpenAI(
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url=OPENROUTER_BASE_URL,
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from enum import Enum
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr, field_validator
|
||||||
|
|
||||||
from backend.blocks._base import (
|
from backend.blocks._base import (
|
||||||
Block,
|
Block,
|
||||||
@@ -13,6 +13,7 @@ from backend.blocks._base import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.block import BlockInput
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -21,6 +22,7 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.clients import OPENROUTER_BASE_URL
|
||||||
from backend.util.logging import TruncatedLogger
|
from backend.util.logging import TruncatedLogger
|
||||||
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||||
@@ -34,6 +36,20 @@ class PerplexityModel(str, Enum):
|
|||||||
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
|
||||||
|
"""Return a valid PerplexityModel, falling back to SONAR for invalid values."""
|
||||||
|
if isinstance(value, PerplexityModel):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return PerplexityModel(value)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid PerplexityModel '{value}', "
|
||||||
|
f"falling back to {PerplexityModel.SONAR.value}"
|
||||||
|
)
|
||||||
|
return PerplexityModel.SONAR
|
||||||
|
|
||||||
|
|
||||||
PerplexityCredentials = CredentialsMetaInput[
|
PerplexityCredentials = CredentialsMetaInput[
|
||||||
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
|
||||||
]
|
]
|
||||||
@@ -72,6 +88,25 @@ class PerplexityBlock(Block):
|
|||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
credentials: PerplexityCredentials = PerplexityCredentialsField()
|
||||||
|
|
||||||
|
@field_validator("model", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
|
||||||
|
"""Fall back to SONAR if the model value is not a valid
|
||||||
|
PerplexityModel (e.g. an OpenAI model ID set by the agent
|
||||||
|
generator)."""
|
||||||
|
return _sanitize_perplexity_model(v)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_data(cls, data: BlockInput) -> str | None:
|
||||||
|
"""Sanitize the model field before JSON schema validation so that
|
||||||
|
invalid values are replaced with the default instead of raising a
|
||||||
|
BlockInputError."""
|
||||||
|
model_value = data.get("model")
|
||||||
|
if model_value is not None:
|
||||||
|
data["model"] = _sanitize_perplexity_model(model_value).value
|
||||||
|
return super().validate_data(data)
|
||||||
|
|
||||||
system_prompt: str = SchemaField(
|
system_prompt: str = SchemaField(
|
||||||
title="System Prompt",
|
title="System Prompt",
|
||||||
default="",
|
default="",
|
||||||
@@ -136,7 +171,7 @@ class PerplexityBlock(Block):
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||||
client = openai.AsyncOpenAI(
|
client = openai.AsyncOpenAI(
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url=OPENROUTER_BASE_URL,
|
||||||
api_key=credentials.api_key.get_secret_value(),
|
api_key=credentials.api_key.get_secret_value(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2232,6 +2232,7 @@ class DeleteRedditPostBlock(Block):
|
|||||||
("post_id", "abc123"),
|
("post_id", "abc123"),
|
||||||
],
|
],
|
||||||
test_mock={"delete_post": lambda creds, post_id: True},
|
test_mock={"delete_post": lambda creds, post_id: True},
|
||||||
|
is_sensitive_action=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -2290,6 +2291,7 @@ class DeleteRedditCommentBlock(Block):
|
|||||||
("comment_id", "xyz789"),
|
("comment_id", "xyz789"),
|
||||||
],
|
],
|
||||||
test_mock={"delete_comment": lambda creds, comment_id: True},
|
test_mock={"delete_comment": lambda creds, comment_id: True},
|
||||||
|
is_sensitive_action=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
|||||||
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
"_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
|
||||||
"_convert_to_color": lambda *args, **kwargs: "black",
|
"_convert_to_color": lambda *args, **kwargs: "black",
|
||||||
},
|
},
|
||||||
|
is_sensitive_action=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
"""Unit tests for PerplexityBlock model fallback behavior."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.perplexity import (
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
PerplexityBlock,
|
||||||
|
PerplexityModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_input(**overrides) -> dict:
|
||||||
|
defaults = {
|
||||||
|
"prompt": "test query",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return defaults
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerplexityModelFallback:
|
||||||
|
"""Tests for fallback_invalid_model field_validator."""
|
||||||
|
|
||||||
|
def test_invalid_model_falls_back_to_sonar(self):
|
||||||
|
inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
|
||||||
|
assert inp.model == PerplexityModel.SONAR
|
||||||
|
|
||||||
|
def test_another_invalid_model_falls_back_to_sonar(self):
|
||||||
|
inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
|
||||||
|
assert inp.model == PerplexityModel.SONAR
|
||||||
|
|
||||||
|
def test_valid_model_string_is_kept(self):
|
||||||
|
inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
|
||||||
|
assert inp.model == PerplexityModel.SONAR_PRO
|
||||||
|
|
||||||
|
def test_valid_enum_value_is_kept(self):
|
||||||
|
inp = PerplexityBlock.Input(
|
||||||
|
**_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
|
||||||
|
)
|
||||||
|
assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
|
||||||
|
|
||||||
|
def test_default_model_when_omitted(self):
|
||||||
|
inp = PerplexityBlock.Input(**_make_input())
|
||||||
|
assert inp.model == PerplexityModel.SONAR
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_value",
|
||||||
|
[
|
||||||
|
"perplexity/sonar",
|
||||||
|
"perplexity/sonar-pro",
|
||||||
|
"perplexity/sonar-deep-research",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_all_valid_models_accepted(self, model_value: str):
|
||||||
|
inp = PerplexityBlock.Input(**_make_input(model=model_value))
|
||||||
|
assert inp.model.value == model_value
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerplexityValidateData:
|
||||||
|
"""Tests for validate_data which runs during block execution (before
|
||||||
|
Pydantic instantiation). Invalid models must be sanitized here so
|
||||||
|
JSON schema validation does not reject them."""
|
||||||
|
|
||||||
|
def test_invalid_model_sanitized_before_schema_validation(self):
|
||||||
|
data = _make_input(model="gpt-5.2-2025-12-11")
|
||||||
|
error = PerplexityBlock.Input.validate_data(data)
|
||||||
|
assert error is None
|
||||||
|
assert data["model"] == PerplexityModel.SONAR.value
|
||||||
|
|
||||||
|
def test_valid_model_unchanged_by_validate_data(self):
|
||||||
|
data = _make_input(model="perplexity/sonar-pro")
|
||||||
|
error = PerplexityBlock.Input.validate_data(data)
|
||||||
|
assert error is None
|
||||||
|
assert data["model"] == "perplexity/sonar-pro"
|
||||||
|
|
||||||
|
def test_missing_model_uses_default(self):
|
||||||
|
data = _make_input() # no model key
|
||||||
|
error = PerplexityBlock.Input.validate_data(data)
|
||||||
|
assert error is None
|
||||||
|
inp = PerplexityBlock.Input(**data)
|
||||||
|
assert inp.model == PerplexityModel.SONAR
|
||||||
@@ -40,7 +40,7 @@ from backend.copilot.response_model import (
|
|||||||
from backend.copilot.service import (
|
from backend.copilot.service import (
|
||||||
_build_system_prompt,
|
_build_system_prompt,
|
||||||
_generate_session_title,
|
_generate_session_title,
|
||||||
client,
|
_get_openai_client,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
from backend.copilot.tools import execute_tool, get_available_tools
|
from backend.copilot.tools import execute_tool, get_available_tools
|
||||||
@@ -89,7 +89,7 @@ async def _compress_session_messages(
|
|||||||
result = await compress_context(
|
result = await compress_context(
|
||||||
messages=messages_dict,
|
messages=messages_dict,
|
||||||
model=config.model,
|
model=config.model,
|
||||||
client=client,
|
client=_get_openai_client(),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||||
@@ -235,7 +235,7 @@ async def stream_chat_completion_baseline(
|
|||||||
)
|
)
|
||||||
if tools:
|
if tools:
|
||||||
create_kwargs["tools"] = tools
|
create_kwargs["tools"] = tools
|
||||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||||
|
|
||||||
# Accumulate streamed response (text + tool calls)
|
# Accumulate streamed response (text + tool calls)
|
||||||
round_text = ""
|
round_text = ""
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""Configuration management for chat system."""
|
"""Configuration management for chat system."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
from backend.util.clients import OPENROUTER_BASE_URL
|
||||||
|
|
||||||
|
|
||||||
class ChatConfig(BaseSettings):
|
class ChatConfig(BaseSettings):
|
||||||
"""Configuration for the chat system."""
|
"""Configuration for the chat system."""
|
||||||
@@ -19,7 +22,7 @@ class ChatConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||||
base_url: str | None = Field(
|
base_url: str | None = Field(
|
||||||
default="https://openrouter.ai/api/v1",
|
default=OPENROUTER_BASE_URL,
|
||||||
description="Base URL for API (e.g., for OpenRouter)",
|
description="Base URL for API (e.g., for OpenRouter)",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -112,9 +115,37 @@ class ChatConfig(BaseSettings):
|
|||||||
description="E2B sandbox template to use for copilot sessions.",
|
description="E2B sandbox template to use for copilot sessions.",
|
||||||
)
|
)
|
||||||
e2b_sandbox_timeout: int = Field(
|
e2b_sandbox_timeout: int = Field(
|
||||||
default=43200, # 12 hours — same as session_ttl
|
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||||
description="E2B sandbox keepalive timeout in seconds.",
|
description="E2B sandbox running-time timeout (seconds). "
|
||||||
|
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||||
|
"mechanism; this is the safety net.",
|
||||||
)
|
)
|
||||||
|
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
|
||||||
|
default="pause",
|
||||||
|
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def e2b_active(self) -> bool:
|
||||||
|
"""True when E2B is enabled and the API key is present.
|
||||||
|
|
||||||
|
Single source of truth for "should we use E2B right now?".
|
||||||
|
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
|
||||||
|
separately at call sites.
|
||||||
|
"""
|
||||||
|
return self.use_e2b_sandbox and bool(self.e2b_api_key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_e2b_api_key(self) -> str | None:
|
||||||
|
"""Return the E2B API key when E2B is enabled and configured, else None.
|
||||||
|
|
||||||
|
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
|
||||||
|
Use in callers::
|
||||||
|
|
||||||
|
if api_key := config.active_e2b_api_key:
|
||||||
|
# E2B is active; api_key is narrowed to str
|
||||||
|
"""
|
||||||
|
return self.e2b_api_key if self.e2b_active else None
|
||||||
|
|
||||||
@field_validator("use_e2b_sandbox", mode="before")
|
@field_validator("use_e2b_sandbox", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -164,7 +195,7 @@ class ChatConfig(BaseSettings):
|
|||||||
if not v:
|
if not v:
|
||||||
v = os.getenv("OPENAI_BASE_URL")
|
v = os.getenv("OPENAI_BASE_URL")
|
||||||
if not v:
|
if not v:
|
||||||
v = "https://openrouter.ai/api/v1"
|
v = OPENROUTER_BASE_URL
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("use_claude_agent_sdk", mode="before")
|
@field_validator("use_claude_agent_sdk", mode="before")
|
||||||
|
|||||||
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
38
autogpt_platform/backend/backend/copilot/config_test.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Unit tests for ChatConfig."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
|
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||||
|
# override the explicit constructor values we pass in each test.
|
||||||
|
_E2B_ENV_VARS = (
|
||||||
|
"CHAT_USE_E2B_SANDBOX",
|
||||||
|
"CHAT_E2B_API_KEY",
|
||||||
|
"E2B_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
for var in _E2B_ENV_VARS:
|
||||||
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestE2BActive:
|
||||||
|
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||||
|
|
||||||
|
def test_both_enabled_and_key_present_returns_true(self):
|
||||||
|
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
|
||||||
|
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
|
||||||
|
assert cfg.e2b_active is True
|
||||||
|
|
||||||
|
def test_enabled_but_missing_key_returns_false(self):
|
||||||
|
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
|
||||||
|
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
|
||||||
|
assert cfg.e2b_active is False
|
||||||
|
|
||||||
|
def test_disabled_returns_false(self):
|
||||||
|
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
|
||||||
|
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
|
||||||
|
assert cfg.e2b_active is False
|
||||||
@@ -6,6 +6,32 @@
|
|||||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||||
|
|
||||||
|
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||||
|
# Used to distinguish CoPilot-generated records from real graph execution records
|
||||||
|
# in PendingHumanReview and other tables.
|
||||||
|
COPILOT_SYNTHETIC_ID_PREFIX = "copilot-"
|
||||||
|
|
||||||
|
# Sub-prefixes for session-scoped and node-scoped synthetic IDs.
|
||||||
|
COPILOT_SESSION_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}session-"
|
||||||
|
COPILOT_NODE_PREFIX = f"{COPILOT_SYNTHETIC_ID_PREFIX}node-"
|
||||||
|
|
||||||
|
# Separator used in synthetic node_exec_id to encode node_id.
|
||||||
|
# Format: "{node_id}:{random_hex}" — extract node_id via rsplit(":", 1)[0]
|
||||||
|
COPILOT_NODE_EXEC_ID_SEPARATOR = ":"
|
||||||
|
|
||||||
# Compaction notice messages shown to users.
|
# Compaction notice messages shown to users.
|
||||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||||
COMPACTION_TOOL_NAME = "context_compaction"
|
COMPACTION_TOOL_NAME = "context_compaction"
|
||||||
|
|
||||||
|
|
||||||
|
def is_copilot_synthetic_id(id_value: str) -> bool:
|
||||||
|
"""Check if an ID is a CoPilot synthetic ID (not from a real graph execution)."""
|
||||||
|
return id_value.startswith(COPILOT_SYNTHETIC_ID_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||||
|
"""Extract node_id from a synthetic node_exec_id.
|
||||||
|
|
||||||
|
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||||
|
"""
|
||||||
|
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||||
|
|||||||
128
autogpt_platform/backend/backend/copilot/context.py
Normal file
128
autogpt_platform/backend/backend/copilot/context.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Shared execution context for copilot SDK tool handlers.
|
||||||
|
|
||||||
|
All context variables and their accessors live here so that
|
||||||
|
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
|
||||||
|
without creating circular dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.data.db_accessors import workspace_db
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from e2b import AsyncSandbox
|
||||||
|
|
||||||
|
# Allowed base directory for the Read tool.
|
||||||
|
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||||
|
|
||||||
|
# Encoded project-directory name for the current session (e.g.
|
||||||
|
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||||
|
# validation can scope tool-results reads to the current session.
|
||||||
|
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||||
|
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||||
|
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
|
"current_session", default=None
|
||||||
|
)
|
||||||
|
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||||
|
"_current_sandbox", default=None
|
||||||
|
)
|
||||||
|
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||||
|
"""Encode a working directory path the same way the Claude CLI does."""
|
||||||
|
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||||
|
|
||||||
|
|
||||||
|
def set_execution_context(
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
sandbox: "AsyncSandbox | None" = None,
|
||||||
|
sdk_cwd: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
_current_session.set(session)
|
||||||
|
_current_sandbox.set(sandbox)
|
||||||
|
_current_sdk_cwd.set(sdk_cwd or "")
|
||||||
|
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||||
|
|
||||||
|
|
||||||
|
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||||
|
"""Return the current (user_id, session) pair for the active request."""
|
||||||
|
return _current_user_id.get(), _current_session.get()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||||
|
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||||
|
return _current_sandbox.get()
|
||||||
|
|
||||||
|
|
||||||
|
def get_sdk_cwd() -> str:
|
||||||
|
"""Return the SDK working directory for the current session (empty string if unset)."""
|
||||||
|
return _current_sdk_cwd.get()
|
||||||
|
|
||||||
|
|
||||||
|
E2B_WORKDIR = "/home/user"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_sandbox_path(path: str) -> str:
|
||||||
|
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||||
|
"""
|
||||||
|
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||||
|
normalized = os.path.normpath(candidate)
|
||||||
|
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||||
|
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||||
|
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||||
|
|
||||||
|
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||||
|
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||||
|
``tools/__init__`` import chain.
|
||||||
|
"""
|
||||||
|
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||||
|
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||||
|
"""Return True if *path* is within an allowed host-filesystem location.
|
||||||
|
|
||||||
|
Allowed:
|
||||||
|
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||||
|
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||||
|
"""
|
||||||
|
if not path:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if path.startswith("~"):
|
||||||
|
resolved = os.path.realpath(os.path.expanduser(path))
|
||||||
|
elif not os.path.isabs(path) and sdk_cwd:
|
||||||
|
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||||
|
else:
|
||||||
|
resolved = os.path.realpath(path)
|
||||||
|
|
||||||
|
if sdk_cwd:
|
||||||
|
norm_cwd = os.path.realpath(sdk_cwd)
|
||||||
|
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||||
|
return True
|
||||||
|
|
||||||
|
encoded = _current_project_dir.get("")
|
||||||
|
if encoded:
|
||||||
|
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||||
|
if resolved == tool_results_dir or resolved.startswith(
|
||||||
|
tool_results_dir + os.sep
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
163
autogpt_platform/backend/backend/copilot/context_test.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Tests for context.py — execution context variables and path helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.copilot.context import (
|
||||||
|
_SDK_PROJECTS_DIR,
|
||||||
|
_current_project_dir,
|
||||||
|
get_current_sandbox,
|
||||||
|
get_execution_context,
|
||||||
|
get_sdk_cwd,
|
||||||
|
is_allowed_local_path,
|
||||||
|
resolve_sandbox_path,
|
||||||
|
set_execution_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session() -> MagicMock:
|
||||||
|
s = MagicMock()
|
||||||
|
s.session_id = "test-session"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Context variable getters
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_execution_context_defaults():
|
||||||
|
"""get_execution_context returns (None, session) when user_id is not set."""
|
||||||
|
set_execution_context(None, _make_session())
|
||||||
|
user_id, session = get_execution_context()
|
||||||
|
assert user_id is None
|
||||||
|
assert session is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_and_get_execution_context():
|
||||||
|
"""set_execution_context stores user_id and session."""
|
||||||
|
mock_session = _make_session()
|
||||||
|
set_execution_context("user-abc", mock_session)
|
||||||
|
user_id, session = get_execution_context()
|
||||||
|
assert user_id == "user-abc"
|
||||||
|
assert session is mock_session
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_sandbox_none_by_default():
|
||||||
|
"""get_current_sandbox returns None when no sandbox is set."""
|
||||||
|
set_execution_context("u1", _make_session(), sandbox=None)
|
||||||
|
assert get_current_sandbox() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_sandbox_returns_set_value():
|
||||||
|
"""get_current_sandbox returns the sandbox set via set_execution_context."""
|
||||||
|
mock_sandbox = MagicMock()
|
||||||
|
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
|
||||||
|
assert get_current_sandbox() is mock_sandbox
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sdk_cwd_empty_when_not_set():
|
||||||
|
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||||
|
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||||
|
assert get_sdk_cwd() == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sdk_cwd_returns_set_value():
|
||||||
|
"""get_sdk_cwd returns the value set via set_execution_context."""
|
||||||
|
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
|
||||||
|
assert get_sdk_cwd() == "/tmp/copilot-test"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# is_allowed_local_path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_empty():
|
||||||
|
assert not is_allowed_local_path("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_inside_sdk_cwd():
|
||||||
|
with tempfile.TemporaryDirectory() as cwd:
|
||||||
|
path = os.path.join(cwd, "file.txt")
|
||||||
|
assert is_allowed_local_path(path, cwd)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_sdk_cwd_itself():
|
||||||
|
with tempfile.TemporaryDirectory() as cwd:
|
||||||
|
assert is_allowed_local_path(cwd, cwd)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_outside_sdk_cwd():
|
||||||
|
with tempfile.TemporaryDirectory() as cwd:
|
||||||
|
assert not is_allowed_local_path("/etc/passwd", cwd)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||||
|
"""Without sdk_cwd or project_dir, all paths are rejected."""
|
||||||
|
_current_project_dir.set("")
|
||||||
|
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_tool_results_dir():
|
||||||
|
"""Files under the tool-results directory for the current project are allowed."""
|
||||||
|
encoded = "test-encoded-dir"
|
||||||
|
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||||
|
path = os.path.join(tool_results_dir, "output.txt")
|
||||||
|
|
||||||
|
_current_project_dir.set(encoded)
|
||||||
|
try:
|
||||||
|
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||||
|
finally:
|
||||||
|
_current_project_dir.set("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||||
|
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||||
|
encoded = "test-encoded-dir"
|
||||||
|
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||||
|
|
||||||
|
_current_project_dir.set(encoded)
|
||||||
|
try:
|
||||||
|
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
|
||||||
|
finally:
|
||||||
|
_current_project_dir.set("")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# resolve_sandbox_path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_absolute_valid():
|
||||||
|
assert (
|
||||||
|
resolve_sandbox_path("/home/user/project/main.py")
|
||||||
|
== "/home/user/project/main.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_relative():
|
||||||
|
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_workdir_itself():
|
||||||
|
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_normalizes_dots():
|
||||||
|
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_escape_raises():
|
||||||
|
with pytest.raises(ValueError, match="/home/user"):
|
||||||
|
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||||
|
with pytest.raises(ValueError, match="/home/user"):
|
||||||
|
resolve_sandbox_path("/etc/passwd")
|
||||||
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
138
autogpt_platform/backend/backend/copilot/optimize_blocks.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Scheduler job to generate LLM-optimized block descriptions.
|
||||||
|
|
||||||
|
Runs periodically to rewrite block descriptions into concise, actionable
|
||||||
|
summaries that help the copilot LLM pick the right blocks during agent
|
||||||
|
generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from backend.blocks import get_blocks
|
||||||
|
from backend.util.clients import get_database_manager_client, get_openai_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = (
|
||||||
|
"You are a technical writer for an automation platform. "
|
||||||
|
"Rewrite the following block description to be concise (under 50 words), "
|
||||||
|
"informative, and actionable. Focus on what the block does and when to "
|
||||||
|
"use it. Output ONLY the rewritten description, nothing else. "
|
||||||
|
"Do not use markdown formatting."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rate-limit delay between sequential LLM calls (seconds)
|
||||||
|
_RATE_LIMIT_DELAY = 0.5
|
||||||
|
# Maximum tokens for optimized description generation
|
||||||
|
_MAX_DESCRIPTION_TOKENS = 150
|
||||||
|
# Model for generating optimized descriptions (fast, cheap)
|
||||||
|
_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
|
||||||
|
"""Call the shared OpenAI client to rewrite each block description."""
|
||||||
|
client = get_openai_client()
|
||||||
|
if client is None:
|
||||||
|
logger.error(
|
||||||
|
"No OpenAI client configured, skipping block description optimization"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results: dict[str, str] = {}
|
||||||
|
for block in blocks:
|
||||||
|
block_id = block["id"]
|
||||||
|
block_name = block["name"]
|
||||||
|
description = block["description"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=_MODEL,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": SYSTEM_PROMPT},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"Block name: {block_name}\nDescription: {description}",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens=_MAX_DESCRIPTION_TOKENS,
|
||||||
|
)
|
||||||
|
optimized = (response.choices[0].message.content or "").strip()
|
||||||
|
if optimized:
|
||||||
|
results[block_id] = optimized
|
||||||
|
logger.debug("Optimized description for %s", block_name)
|
||||||
|
else:
|
||||||
|
logger.warning("Empty response for block %s", block_name)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to optimize description for %s", block_name, exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(_RATE_LIMIT_DELAY)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_block_descriptions() -> dict[str, int]:
|
||||||
|
"""Generate optimized descriptions for blocks that don't have one yet.
|
||||||
|
|
||||||
|
Uses the shared OpenAI client to rewrite block descriptions into concise
|
||||||
|
summaries suitable for agent generation prompts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with counts: processed, success, failed, skipped.
|
||||||
|
"""
|
||||||
|
db_client = get_database_manager_client()
|
||||||
|
|
||||||
|
blocks = db_client.get_blocks_needing_optimization()
|
||||||
|
if not blocks:
|
||||||
|
logger.info("All blocks already have optimized descriptions")
|
||||||
|
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
|
||||||
|
|
||||||
|
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
|
||||||
|
|
||||||
|
non_empty = [b for b in blocks if b.get("description", "").strip()]
|
||||||
|
skipped = len(blocks) - len(non_empty)
|
||||||
|
|
||||||
|
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"processed": len(non_empty),
|
||||||
|
"success": len(new_descriptions),
|
||||||
|
"failed": len(non_empty) - len(new_descriptions),
|
||||||
|
"skipped": skipped,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Block description optimization complete: "
|
||||||
|
"%d/%d succeeded, %d failed, %d skipped",
|
||||||
|
stats["success"],
|
||||||
|
stats["processed"],
|
||||||
|
stats["failed"],
|
||||||
|
stats["skipped"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_descriptions:
|
||||||
|
for block_id, optimized in new_descriptions.items():
|
||||||
|
db_client.update_block_optimized_description(block_id, optimized)
|
||||||
|
|
||||||
|
# Update in-memory descriptions first so the cache rebuilds with fresh data.
|
||||||
|
try:
|
||||||
|
block_classes = get_blocks()
|
||||||
|
for block_id, optimized in new_descriptions.items():
|
||||||
|
if block_id in block_classes:
|
||||||
|
block_classes[block_id]._optimized_description = optimized
|
||||||
|
logger.info(
|
||||||
|
"Updated %d in-memory block descriptions", len(new_descriptions)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Could not update in-memory block descriptions", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
from backend.copilot.tools.agent_generator.blocks import (
|
||||||
|
reset_block_caches, # local to avoid circular import
|
||||||
|
)
|
||||||
|
|
||||||
|
reset_block_caches()
|
||||||
|
|
||||||
|
return stats
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
"""Unit tests for optimize_blocks._optimize_descriptions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client_response(text: str) -> MagicMock:
|
||||||
|
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
|
||||||
|
choice = MagicMock()
|
||||||
|
choice.message.content = text
|
||||||
|
response = MagicMock()
|
||||||
|
response.choices = [choice]
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.get_event_loop().run_until_complete(coro)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizeDescriptions:
|
||||||
|
"""Tests for _optimize_descriptions async function."""
|
||||||
|
|
||||||
|
def test_returns_empty_when_no_client(self):
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
|
||||||
|
):
|
||||||
|
result = _run(
|
||||||
|
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
|
||||||
|
)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_success_single_block(self):
|
||||||
|
client = MagicMock()
|
||||||
|
client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=_make_client_response("Short desc.")
|
||||||
|
)
|
||||||
|
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = _run(_optimize_descriptions(blocks))
|
||||||
|
|
||||||
|
assert result == {"b1": "Short desc."}
|
||||||
|
client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
def test_skips_block_on_exception(self):
|
||||||
|
client = MagicMock()
|
||||||
|
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
|
||||||
|
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = _run(_optimize_descriptions(blocks))
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_sleeps_between_blocks(self):
|
||||||
|
client = MagicMock()
|
||||||
|
client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=_make_client_response("desc")
|
||||||
|
)
|
||||||
|
blocks = [
|
||||||
|
{"id": "b1", "name": "B1", "description": "d1"},
|
||||||
|
{"id": "b2", "name": "B2", "description": "d2"},
|
||||||
|
]
|
||||||
|
sleep_mock = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||||
|
),
|
||||||
|
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
|
||||||
|
):
|
||||||
|
_run(_optimize_descriptions(blocks))
|
||||||
|
|
||||||
|
assert sleep_mock.call_count == 2
|
||||||
|
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)
|
||||||
@@ -26,6 +26,70 @@ your message as a Markdown link or image:
|
|||||||
The `download_url` field in the `write_workspace_file` response is already
|
The `download_url` field in the `write_workspace_file` response is already
|
||||||
in the correct format — paste it directly after the `(` in the Markdown.
|
in the correct format — paste it directly after the `(` in the Markdown.
|
||||||
|
|
||||||
|
### Passing file content to tools — @@agptfile: references
|
||||||
|
Instead of copying large file contents into a tool argument, pass a file
|
||||||
|
reference and the platform will load the content for you.
|
||||||
|
|
||||||
|
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||||
|
|
||||||
|
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||||
|
- `workspace://<file_id>` — workspace file by ID
|
||||||
|
- `workspace:///<path>` — workspace file by virtual path
|
||||||
|
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||||
|
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||||
|
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||||
|
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```
|
||||||
|
@@agptfile:workspace://abc123
|
||||||
|
@@agptfile:workspace://abc123[10-50]
|
||||||
|
@@agptfile:workspace:///reports/q1.md
|
||||||
|
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||||
|
@@agptfile:/home/user/script.py
|
||||||
|
```
|
||||||
|
|
||||||
|
You can embed a reference inside any string argument, or use it as the entire
|
||||||
|
value. Multiple references in one argument are all expanded.
|
||||||
|
|
||||||
|
**Structured data**: When the **entire** argument value is a single file
|
||||||
|
reference (no surrounding text), the platform automatically parses the file
|
||||||
|
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||||
|
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||||
|
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||||
|
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||||
|
unrecognised or parsing fails, the content is returned as a plain string.
|
||||||
|
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||||
|
|
||||||
|
**Type coercion**: The platform also coerces expanded values to match the
|
||||||
|
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||||
|
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||||
|
|
||||||
|
### Media file inputs (format: "file")
|
||||||
|
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||||
|
These fields accept:
|
||||||
|
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||||
|
for large files (images, videos, PDFs). The platform passes the reference
|
||||||
|
directly to the block without reading the content into memory.
|
||||||
|
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||||
|
small files only.
|
||||||
|
|
||||||
|
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||||
|
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||||
|
payloads in tool arguments and preserves binary content (images, videos)
|
||||||
|
that would be corrupted by text encoding.
|
||||||
|
|
||||||
|
Example — committing an image file to GitHub:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"files": [{
|
||||||
|
"path": "docs/hero.png",
|
||||||
|
"content": "workspace://abc123#image/png",
|
||||||
|
"operation": "upsert"
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Sub-agent tasks
|
### Sub-agent tasks
|
||||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||||
All tasks must run in the foreground.
|
All tasks must run in the foreground.
|
||||||
|
|||||||
@@ -3,12 +3,45 @@
|
|||||||
This module provides the integration layer between the Claude Agent SDK
|
This module provides the integration layer between the Claude Agent SDK
|
||||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||||
|
|
||||||
|
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||||
|
circular import cycle::
|
||||||
|
|
||||||
|
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||||
|
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||||
|
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||||
|
|
||||||
|
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||||
|
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||||
|
function scope without a larger refactor (moving tool-name registration
|
||||||
|
to a separate lightweight module). The lazy-import pattern here is the
|
||||||
|
least invasive way to break the cycle while keeping module-level constants
|
||||||
|
intact.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .service import stream_chat_completion_sdk
|
from typing import Any
|
||||||
from .tool_adapter import create_copilot_mcp_server
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"stream_chat_completion_sdk",
|
"stream_chat_completion_sdk",
|
||||||
"create_copilot_mcp_server",
|
"create_copilot_mcp_server",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||||
|
# pair so new exports can be added without touching __getattr__ itself.
|
||||||
|
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||||
|
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||||
|
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
entry = _LAZY_IMPORTS.get(name)
|
||||||
|
if entry is not None:
|
||||||
|
module_path, attr = entry
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
module = importlib.import_module(module_path, package=__name__)
|
||||||
|
value = getattr(module, attr)
|
||||||
|
globals()[name] = value
|
||||||
|
return value
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
## Agent Generation Guide
|
||||||
|
|
||||||
|
You can create, edit, and customize agents directly. You ARE the brain —
|
||||||
|
generate the agent JSON yourself using block schemas, then validate and save.
|
||||||
|
|
||||||
|
### Workflow for Creating/Editing Agents
|
||||||
|
|
||||||
|
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||||
|
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||||
|
and full input/output schemas.
|
||||||
|
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||||
|
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||||
|
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||||
|
- Use block IDs from step 1 as `block_id` in nodes
|
||||||
|
- Wire outputs to inputs using links
|
||||||
|
- Set design-time config in `input_default`
|
||||||
|
- Use `AgentInputBlock` for values the user provides at runtime
|
||||||
|
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||||
|
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||||
|
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||||
|
for errors
|
||||||
|
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||||
|
or fix manually based on the error descriptions. Iterate until valid.
|
||||||
|
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||||
|
the final `agent_json`
|
||||||
|
|
||||||
|
### Agent JSON Structure
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "<UUID v4>", // auto-generated if omitted
|
||||||
|
"version": 1,
|
||||||
|
"is_active": true,
|
||||||
|
"name": "Agent Name",
|
||||||
|
"description": "What the agent does",
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "<UUID v4>",
|
||||||
|
"block_id": "<block UUID from find_block>",
|
||||||
|
"input_default": {
|
||||||
|
"field_name": "design-time value"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"position": {"x": 0, "y": 0},
|
||||||
|
"customized_name": "Optional display name"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"id": "<UUID v4>",
|
||||||
|
"source_id": "<source node UUID>",
|
||||||
|
"source_name": "output_field_name",
|
||||||
|
"sink_id": "<sink node UUID>",
|
||||||
|
"sink_name": "input_field_name",
|
||||||
|
"is_static": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### REQUIRED: AgentInputBlock and AgentOutputBlock
|
||||||
|
|
||||||
|
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
|
||||||
|
These define the agent's interface — what it accepts and what it produces.
|
||||||
|
|
||||||
|
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||||
|
- Defines a user-facing input field on the agent
|
||||||
|
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||||
|
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||||
|
- Output: `result` — the user-provided value at runtime
|
||||||
|
- Create one AgentInputBlock per distinct input the agent needs
|
||||||
|
|
||||||
|
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||||
|
- Defines a user-facing output displayed after the agent runs
|
||||||
|
- Required `input_default` fields: `name` (str)
|
||||||
|
- The `value` input should be linked from another block's output
|
||||||
|
- Optional: `title`, `description`, `format` (Jinja2 template)
|
||||||
|
- Create one AgentOutputBlock per distinct result to show the user
|
||||||
|
|
||||||
|
Without these blocks, the agent has no interface and the user cannot provide
|
||||||
|
inputs or see outputs. NEVER skip them.
|
||||||
|
|
||||||
|
### Key Rules
|
||||||
|
|
||||||
|
- **Name & description**: Include `name` and `description` in the agent JSON
|
||||||
|
when creating a new agent, or when editing and the agent's purpose changed.
|
||||||
|
Without these the agent gets a generic default name.
|
||||||
|
- **Design-time vs runtime**: `input_default` = values known at build time.
|
||||||
|
For user-provided values, create an `AgentInputBlock` node and link its
|
||||||
|
output to the consuming block's input.
|
||||||
|
- **Credentials**: Do NOT require credentials upfront. Users configure
|
||||||
|
credentials later in the platform UI after the agent is saved.
|
||||||
|
- **Node spacing**: Position nodes with at least 800 X-units between them.
|
||||||
|
- **Nested properties**: Use `parentField_#_childField` notation in link
|
||||||
|
sink_name/source_name to access nested object fields.
|
||||||
|
- **is_static links**: Set `is_static: true` when the link carries a
|
||||||
|
design-time constant (matches a field in inputSchema with a default).
|
||||||
|
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
|
||||||
|
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
|
||||||
|
literal braces in prompt strings — single `{` and `}` are for
|
||||||
|
template variables.
|
||||||
|
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
|
||||||
|
`graph_version` in input_default, and wire inputs/outputs to match
|
||||||
|
the sub-agent's schema.
|
||||||
|
|
||||||
|
### Using Sub-Agents (AgentExecutorBlock)
|
||||||
|
|
||||||
|
To compose agents using other agents as sub-agents:
|
||||||
|
1. Call `find_library_agent` to find the sub-agent — the response includes
|
||||||
|
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
|
||||||
|
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
|
||||||
|
3. Set `input_default`:
|
||||||
|
- `graph_id`: from the library agent's `graph_id`
|
||||||
|
- `graph_version`: from the library agent's `graph_version`
|
||||||
|
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
|
||||||
|
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
|
||||||
|
- `user_id`: leave as `""` (filled at runtime)
|
||||||
|
- `inputs`: `{}` (populated by links at runtime)
|
||||||
|
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
|
||||||
|
property names (e.g., if input_schema has a `"url"` property, use
|
||||||
|
`"url"` as the sink_name)
|
||||||
|
5. Wire outputs: link from source names matching the sub-agent's
|
||||||
|
`output_schema` property names
|
||||||
|
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
|
||||||
|
the library agent IDs used, so the fixer can validate schemas
|
||||||
|
|
||||||
|
### Using MCP Tools (MCPToolBlock)
|
||||||
|
|
||||||
|
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||||
|
1. The user must specify which MCP server URL and tool name they want
|
||||||
|
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||||
|
3. Set `input_default`:
|
||||||
|
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
|
||||||
|
- `selected_tool`: the tool name on that server
|
||||||
|
- `tool_input_schema`: JSON Schema for the tool's inputs
|
||||||
|
- `tool_arguments`: `{}` (populated by links or hardcoded values)
|
||||||
|
4. The block requires MCP credentials — the user configures these in the
|
||||||
|
platform UI after the agent is saved
|
||||||
|
5. Wire inputs using the tool argument field name directly as the sink_name
|
||||||
|
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
|
||||||
|
automatically collects top-level fields matching tool_input_schema into
|
||||||
|
tool_arguments.
|
||||||
|
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||||
|
|
||||||
|
### Example: Simple AI Text Processor
|
||||||
|
|
||||||
|
A minimal agent with input, processing, and output:
|
||||||
|
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
|
||||||
|
input_default: {"name": "user_text", "title": "Text to process"},
|
||||||
|
output: "result")
|
||||||
|
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
|
||||||
|
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
|
||||||
|
input_default: {"name": "summary", "title": "Summary"},
|
||||||
|
input: "value" linked from Node 2's output)
|
||||||
@@ -11,7 +11,7 @@ persistence, and the ``CompactionTracker`` state machine.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||||
from ..model import ChatMessage, ChatSession
|
from ..model import ChatMessage, ChatSession
|
||||||
@@ -27,6 +27,19 @@ from ..response_model import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompactionResult:
|
||||||
|
"""Result of emit_end_if_ready — bundles events with compaction metadata.
|
||||||
|
|
||||||
|
Eliminates the need for separate ``compaction_just_ended`` checks,
|
||||||
|
preventing TOCTOU races between the emit call and the flag read.
|
||||||
|
"""
|
||||||
|
|
||||||
|
events: list[StreamBaseResponse] = field(default_factory=list)
|
||||||
|
just_ended: bool = False
|
||||||
|
transcript_path: str = ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Event builders (private — use CompactionTracker or compaction_events)
|
# Event builders (private — use CompactionTracker or compaction_events)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -177,11 +190,22 @@ class CompactionTracker:
|
|||||||
self._start_emitted = False
|
self._start_emitted = False
|
||||||
self._done = False
|
self._done = False
|
||||||
self._tool_call_id = ""
|
self._tool_call_id = ""
|
||||||
|
self._transcript_path: str = ""
|
||||||
|
|
||||||
@property
|
def on_compact(self, transcript_path: str = "") -> None:
|
||||||
def on_compact(self) -> Callable[[], None]:
|
"""Callback for the PreCompact hook. Stores transcript_path."""
|
||||||
"""Callback for the PreCompact hook."""
|
if (
|
||||||
return self._compact_start.set
|
self._transcript_path
|
||||||
|
and transcript_path
|
||||||
|
and self._transcript_path != transcript_path
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"[Compaction] Overwriting transcript_path %s -> %s",
|
||||||
|
self._transcript_path,
|
||||||
|
transcript_path,
|
||||||
|
)
|
||||||
|
self._transcript_path = transcript_path
|
||||||
|
self._compact_start.set()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Pre-query compaction
|
# Pre-query compaction
|
||||||
@@ -201,6 +225,7 @@ class CompactionTracker:
|
|||||||
self._done = False
|
self._done = False
|
||||||
self._start_emitted = False
|
self._start_emitted = False
|
||||||
self._tool_call_id = ""
|
self._tool_call_id = ""
|
||||||
|
self._transcript_path = ""
|
||||||
|
|
||||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||||
@@ -211,15 +236,20 @@ class CompactionTracker:
|
|||||||
return _start_events(self._tool_call_id)
|
return _start_events(self._tool_call_id)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
async def emit_end_if_ready(self, session: ChatSession) -> CompactionResult:
|
||||||
"""If compaction is in progress, emit end events and persist."""
|
"""If compaction is in progress, emit end events and persist.
|
||||||
|
|
||||||
|
Returns a ``CompactionResult`` with ``just_ended=True`` and the
|
||||||
|
captured ``transcript_path`` when a compaction cycle completes.
|
||||||
|
This avoids a separate flag check (TOCTOU-safe).
|
||||||
|
"""
|
||||||
# Yield so pending hook tasks can set compact_start
|
# Yield so pending hook tasks can set compact_start
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
if self._done:
|
if self._done:
|
||||||
return []
|
return CompactionResult()
|
||||||
if not self._start_emitted and not self._compact_start.is_set():
|
if not self._start_emitted and not self._compact_start.is_set():
|
||||||
return []
|
return CompactionResult()
|
||||||
|
|
||||||
if self._start_emitted:
|
if self._start_emitted:
|
||||||
# Close the open spinner
|
# Close the open spinner
|
||||||
@@ -232,8 +262,12 @@ class CompactionTracker:
|
|||||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
transcript_path = self._transcript_path
|
||||||
self._compact_start.clear()
|
self._compact_start.clear()
|
||||||
self._start_emitted = False
|
self._start_emitted = False
|
||||||
self._done = True
|
self._done = True
|
||||||
|
self._transcript_path = ""
|
||||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||||
return done_events
|
return CompactionResult(
|
||||||
|
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||||
|
)
|
||||||
|
|||||||
@@ -195,10 +195,11 @@ class TestCompactionTracker:
|
|||||||
session = _make_session()
|
session = _make_session()
|
||||||
tracker.on_compact()
|
tracker.on_compact()
|
||||||
tracker.emit_start_if_ready()
|
tracker.emit_start_if_ready()
|
||||||
evts = await tracker.emit_end_if_ready(session)
|
result = await tracker.emit_end_if_ready(session)
|
||||||
assert len(evts) == 2
|
assert result.just_ended is True
|
||||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
assert len(result.events) == 2
|
||||||
assert isinstance(evts[1], StreamFinishStep)
|
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||||
|
assert isinstance(result.events[1], StreamFinishStep)
|
||||||
# Should persist
|
# Should persist
|
||||||
assert len(session.messages) == 2
|
assert len(session.messages) == 2
|
||||||
|
|
||||||
@@ -210,28 +211,32 @@ class TestCompactionTracker:
|
|||||||
session = _make_session()
|
session = _make_session()
|
||||||
tracker.on_compact()
|
tracker.on_compact()
|
||||||
# Don't call emit_start_if_ready
|
# Don't call emit_start_if_ready
|
||||||
evts = await tracker.emit_end_if_ready(session)
|
result = await tracker.emit_end_if_ready(session)
|
||||||
assert len(evts) == 5 # Full self-contained event
|
assert result.just_ended is True
|
||||||
assert isinstance(evts[0], StreamStartStep)
|
assert len(result.events) == 5 # Full self-contained event
|
||||||
|
assert isinstance(result.events[0], StreamStartStep)
|
||||||
assert len(session.messages) == 2
|
assert len(session.messages) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_emit_end_no_op_when_done(self):
|
async def test_emit_end_no_op_when_no_new_compaction(self):
|
||||||
tracker = CompactionTracker()
|
tracker = CompactionTracker()
|
||||||
session = _make_session()
|
session = _make_session()
|
||||||
tracker.on_compact()
|
tracker.on_compact()
|
||||||
tracker.emit_start_if_ready()
|
tracker.emit_start_if_ready()
|
||||||
await tracker.emit_end_if_ready(session)
|
result1 = await tracker.emit_end_if_ready(session)
|
||||||
# Second call should be no-op
|
assert result1.just_ended is True
|
||||||
evts = await tracker.emit_end_if_ready(session)
|
# Second call should be no-op (no new on_compact)
|
||||||
assert evts == []
|
result2 = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result2.just_ended is False
|
||||||
|
assert result2.events == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||||
tracker = CompactionTracker()
|
tracker = CompactionTracker()
|
||||||
session = _make_session()
|
session = _make_session()
|
||||||
evts = await tracker.emit_end_if_ready(session)
|
result = await tracker.emit_end_if_ready(session)
|
||||||
assert evts == []
|
assert result.just_ended is False
|
||||||
|
assert result.events == []
|
||||||
|
|
||||||
def test_emit_pre_query(self):
|
def test_emit_pre_query(self):
|
||||||
tracker = CompactionTracker()
|
tracker = CompactionTracker()
|
||||||
@@ -246,20 +251,29 @@ class TestCompactionTracker:
|
|||||||
tracker._done = True
|
tracker._done = True
|
||||||
tracker._start_emitted = True
|
tracker._start_emitted = True
|
||||||
tracker._tool_call_id = "old"
|
tracker._tool_call_id = "old"
|
||||||
|
tracker._transcript_path = "/some/path"
|
||||||
tracker.reset_for_query()
|
tracker.reset_for_query()
|
||||||
assert tracker._done is False
|
assert tracker._done is False
|
||||||
assert tracker._start_emitted is False
|
assert tracker._start_emitted is False
|
||||||
assert tracker._tool_call_id == ""
|
assert tracker._tool_call_id == ""
|
||||||
|
assert tracker._transcript_path == ""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pre_query_blocks_sdk_compaction(self):
|
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
|
||||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
"""After pre-query compaction, SDK compaction is blocked until
|
||||||
|
reset_for_query is called."""
|
||||||
tracker = CompactionTracker()
|
tracker = CompactionTracker()
|
||||||
session = _make_session()
|
session = _make_session()
|
||||||
tracker.emit_pre_query(session)
|
tracker.emit_pre_query(session)
|
||||||
tracker.on_compact()
|
tracker.on_compact()
|
||||||
|
# _done is True so emit_start_if_ready is blocked
|
||||||
evts = tracker.emit_start_if_ready()
|
evts = tracker.emit_start_if_ready()
|
||||||
assert evts == [] # _done blocks it
|
assert evts == []
|
||||||
|
# Reset clears _done, allowing subsequent compaction
|
||||||
|
tracker.reset_for_query()
|
||||||
|
tracker.on_compact()
|
||||||
|
evts = tracker.emit_start_if_ready()
|
||||||
|
assert len(evts) == 3
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reset_allows_new_compaction(self):
|
async def test_reset_allows_new_compaction(self):
|
||||||
@@ -279,9 +293,9 @@ class TestCompactionTracker:
|
|||||||
session = _make_session()
|
session = _make_session()
|
||||||
tracker.on_compact()
|
tracker.on_compact()
|
||||||
start_evts = tracker.emit_start_if_ready()
|
start_evts = tracker.emit_start_if_ready()
|
||||||
end_evts = await tracker.emit_end_if_ready(session)
|
result = await tracker.emit_end_if_ready(session)
|
||||||
start_evt = start_evts[1]
|
start_evt = start_evts[1]
|
||||||
end_evt = end_evts[0]
|
end_evt = result.events[0]
|
||||||
assert isinstance(start_evt, StreamToolInputStart)
|
assert isinstance(start_evt, StreamToolInputStart)
|
||||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||||
assert start_evt.toolCallId == end_evt.toolCallId
|
assert start_evt.toolCallId == end_evt.toolCallId
|
||||||
@@ -289,3 +303,105 @@ class TestCompactionTracker:
|
|||||||
tool_calls = session.messages[0].tool_calls
|
tool_calls = session.messages[0].tool_calls
|
||||||
assert tool_calls is not None
|
assert tool_calls is not None
|
||||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_compactions_within_query(self):
|
||||||
|
"""Two mid-stream compactions within a single query both trigger."""
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = _make_session()
|
||||||
|
|
||||||
|
# First compaction cycle
|
||||||
|
tracker.on_compact("/path/1")
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
result1 = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result1.just_ended is True
|
||||||
|
assert len(result1.events) == 2
|
||||||
|
assert result1.transcript_path == "/path/1"
|
||||||
|
|
||||||
|
# Second compaction cycle (should NOT be blocked — _done resets
|
||||||
|
# because emit_end_if_ready sets it True, but the next on_compact
|
||||||
|
# + emit_start_if_ready checks !_done which IS True now.
|
||||||
|
# So we need reset_for_query between queries, but within a single
|
||||||
|
# query multiple compactions work because _done blocks emit_start
|
||||||
|
# until the next message arrives, at which point emit_end detects it)
|
||||||
|
#
|
||||||
|
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||||
|
# the stream loop to reset. In practice service.py doesn't call
|
||||||
|
# reset between compactions within the same query — let's verify
|
||||||
|
# the actual behavior.
|
||||||
|
tracker.on_compact("/path/2")
|
||||||
|
# _done is True from first compaction, so start is blocked
|
||||||
|
start_evts = tracker.emit_start_if_ready()
|
||||||
|
assert start_evts == []
|
||||||
|
# But emit_end returns no-op because _done is True
|
||||||
|
result2 = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result2.just_ended is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_compactions_with_intervening_message(self):
|
||||||
|
"""Multiple compactions work when the stream loop processes messages between them.
|
||||||
|
|
||||||
|
In the real service.py flow:
|
||||||
|
1. PreCompact fires → on_compact()
|
||||||
|
2. emit_start shows spinner
|
||||||
|
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||||
|
4. Stream continues processing messages...
|
||||||
|
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||||
|
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||||
|
7. The stream loop needs to detect this and handle accordingly
|
||||||
|
|
||||||
|
The actual flow for multiple compactions within a query requires
|
||||||
|
_done to be cleared between them. The service.py code uses
|
||||||
|
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||||
|
stays True until reset_for_query.
|
||||||
|
"""
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = _make_session()
|
||||||
|
|
||||||
|
# First compaction
|
||||||
|
tracker.on_compact("/path/1")
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
result1 = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result1.just_ended is True
|
||||||
|
assert result1.transcript_path == "/path/1"
|
||||||
|
|
||||||
|
# Simulate reset between queries
|
||||||
|
tracker.reset_for_query()
|
||||||
|
|
||||||
|
# Second compaction in new query
|
||||||
|
tracker.on_compact("/path/2")
|
||||||
|
start_evts = tracker.emit_start_if_ready()
|
||||||
|
assert len(start_evts) == 3
|
||||||
|
result2 = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result2.just_ended is True
|
||||||
|
assert result2.transcript_path == "/path/2"
|
||||||
|
|
||||||
|
def test_on_compact_stores_transcript_path(self):
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
tracker.on_compact("/some/path.jsonl")
|
||||||
|
assert tracker._transcript_path == "/some/path.jsonl"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emit_end_returns_transcript_path(self):
|
||||||
|
"""CompactionResult includes the transcript_path from on_compact."""
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = _make_session()
|
||||||
|
tracker.on_compact("/my/session.jsonl")
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
result = await tracker.emit_end_if_ready(session)
|
||||||
|
assert result.just_ended is True
|
||||||
|
assert result.transcript_path == "/my/session.jsonl"
|
||||||
|
# transcript_path is cleared after emit_end
|
||||||
|
assert tracker._transcript_path == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emit_end_clears_transcript_path(self):
|
||||||
|
"""After emit_end, _transcript_path is reset so it doesn't leak to
|
||||||
|
subsequent non-compaction emit_end calls."""
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = _make_session()
|
||||||
|
tracker.on_compact("/first/path.jsonl")
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
await tracker.emit_end_if_ready(session)
|
||||||
|
# After compaction, _transcript_path is cleared
|
||||||
|
assert tracker._transcript_path == ""
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
|||||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -17,36 +15,23 @@ import os
|
|||||||
import shlex
|
import shlex
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
|
from backend.copilot.context import (
|
||||||
|
E2B_WORKDIR,
|
||||||
|
get_current_sandbox,
|
||||||
|
get_sdk_cwd,
|
||||||
|
is_allowed_local_path,
|
||||||
|
resolve_sandbox_path,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Lazy imports to break circular dependency with tool_adapter.
|
def _get_sandbox():
|
||||||
|
|
||||||
|
|
||||||
def _get_sandbox(): # type: ignore[return]
|
|
||||||
from .tool_adapter import get_current_sandbox # noqa: E402
|
|
||||||
|
|
||||||
return get_current_sandbox()
|
return get_current_sandbox()
|
||||||
|
|
||||||
|
|
||||||
def _is_allowed_local(path: str) -> bool:
|
def _is_allowed_local(path: str) -> bool:
|
||||||
from .tool_adapter import is_allowed_local_path # noqa: E402
|
return is_allowed_local_path(path, get_sdk_cwd())
|
||||||
|
|
||||||
return is_allowed_local_path(path)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_remote(path: str) -> str:
|
|
||||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
|
||||||
|
|
||||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
|
||||||
"""
|
|
||||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
|
||||||
normalized = os.path.normpath(candidate)
|
|
||||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
|
||||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
|
|
||||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||||
@@ -63,7 +48,7 @@ def _get_sandbox_and_path(
|
|||||||
if sandbox is None:
|
if sandbox is None:
|
||||||
return _mcp("No E2B sandbox available", error=True)
|
return _mcp("No E2B sandbox available", error=True)
|
||||||
try:
|
try:
|
||||||
remote = _resolve_remote(file_path)
|
remote = resolve_sandbox_path(file_path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return _mcp(str(exc), error=True)
|
return _mcp(str(exc), error=True)
|
||||||
return sandbox, remote
|
return sandbox, remote
|
||||||
@@ -73,6 +58,7 @@ def _get_sandbox_and_path(
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||||
file_path: str = args.get("file_path", "")
|
file_path: str = args.get("file_path", "")
|
||||||
offset: int = max(0, int(args.get("offset", 0)))
|
offset: int = max(0, int(args.get("offset", 0)))
|
||||||
limit: int = max(1, int(args.get("limit", 2000)))
|
limit: int = max(1, int(args.get("limit", 2000)))
|
||||||
@@ -104,6 +90,7 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Write content to a sandbox file, creating parent directories as needed."""
|
||||||
file_path: str = args.get("file_path", "")
|
file_path: str = args.get("file_path", "")
|
||||||
content: str = args.get("content", "")
|
content: str = args.get("content", "")
|
||||||
|
|
||||||
@@ -127,6 +114,7 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Replace a substring in a sandbox file, with optional replace-all support."""
|
||||||
file_path: str = args.get("file_path", "")
|
file_path: str = args.get("file_path", "")
|
||||||
old_string: str = args.get("old_string", "")
|
old_string: str = args.get("old_string", "")
|
||||||
new_string: str = args.get("new_string", "")
|
new_string: str = args.get("new_string", "")
|
||||||
@@ -172,6 +160,7 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Find files matching a name pattern inside the sandbox using ``find``."""
|
||||||
pattern: str = args.get("pattern", "")
|
pattern: str = args.get("pattern", "")
|
||||||
path: str = args.get("path", "")
|
path: str = args.get("path", "")
|
||||||
|
|
||||||
@@ -183,7 +172,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return _mcp("No E2B sandbox available", error=True)
|
return _mcp("No E2B sandbox available", error=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return _mcp(str(exc), error=True)
|
return _mcp(str(exc), error=True)
|
||||||
|
|
||||||
@@ -198,6 +187,7 @@ async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
|
||||||
pattern: str = args.get("pattern", "")
|
pattern: str = args.get("pattern", "")
|
||||||
path: str = args.get("path", "")
|
path: str = args.get("path", "")
|
||||||
include: str = args.get("include", "")
|
include: str = args.get("include", "")
|
||||||
@@ -210,7 +200,7 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return _mcp("No E2B sandbox available", error=True)
|
return _mcp("No E2B sandbox available", error=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
|
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
return _mcp(str(exc), error=True)
|
return _mcp(str(exc), error=True)
|
||||||
|
|
||||||
@@ -238,7 +228,7 @@ def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
|||||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||||
try:
|
try:
|
||||||
with open(expanded) as fh:
|
with open(expanded, encoding="utf-8", errors="replace") as fh:
|
||||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||||
numbered = "".join(
|
numbered = "".join(
|
||||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||||
|
|||||||
@@ -7,59 +7,60 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .e2b_file_tools import _read_local, _resolve_remote
|
from backend.copilot.context import _current_project_dir
|
||||||
from .tool_adapter import _current_project_dir
|
|
||||||
|
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||||
|
|
||||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _resolve_remote — sandbox path normalisation & boundary enforcement
|
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestResolveRemote:
|
class TestResolveSandboxPath:
|
||||||
def test_relative_path_resolved(self):
|
def test_relative_path_resolved(self):
|
||||||
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
|
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||||
|
|
||||||
def test_absolute_within_sandbox(self):
|
def test_absolute_within_sandbox(self):
|
||||||
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
|
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||||
|
|
||||||
def test_workdir_itself(self):
|
def test_workdir_itself(self):
|
||||||
assert _resolve_remote("/home/user") == "/home/user"
|
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||||
|
|
||||||
def test_relative_dotslash(self):
|
def test_relative_dotslash(self):
|
||||||
assert _resolve_remote("./README.md") == "/home/user/README.md"
|
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||||
|
|
||||||
def test_traversal_blocked(self):
|
def test_traversal_blocked(self):
|
||||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||||
_resolve_remote("../../etc/passwd")
|
resolve_sandbox_path("../../etc/passwd")
|
||||||
|
|
||||||
def test_absolute_traversal_blocked(self):
|
def test_absolute_traversal_blocked(self):
|
||||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||||
_resolve_remote("/home/user/../../etc/passwd")
|
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||||
|
|
||||||
def test_absolute_outside_sandbox_blocked(self):
|
def test_absolute_outside_sandbox_blocked(self):
|
||||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||||
_resolve_remote("/etc/passwd")
|
resolve_sandbox_path("/etc/passwd")
|
||||||
|
|
||||||
def test_root_blocked(self):
|
def test_root_blocked(self):
|
||||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||||
_resolve_remote("/")
|
resolve_sandbox_path("/")
|
||||||
|
|
||||||
def test_home_other_user_blocked(self):
|
def test_home_other_user_blocked(self):
|
||||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||||
_resolve_remote("/home/other/file.txt")
|
resolve_sandbox_path("/home/other/file.txt")
|
||||||
|
|
||||||
def test_deep_nested_allowed(self):
|
def test_deep_nested_allowed(self):
|
||||||
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||||
|
|
||||||
def test_trailing_slash_normalised(self):
|
def test_trailing_slash_normalised(self):
|
||||||
assert _resolve_remote("src/") == "/home/user/src"
|
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||||
|
|
||||||
def test_double_dots_within_sandbox_ok(self):
|
def test_double_dots_within_sandbox_ok(self):
|
||||||
"""Path that resolves back within /home/user is allowed."""
|
"""Path that resolves back within /home/user is allowed."""
|
||||||
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
|
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,531 @@
|
|||||||
|
"""End-to-end compaction flow test.
|
||||||
|
|
||||||
|
Simulates the full service.py compaction lifecycle using real-format
|
||||||
|
JSONL session files — no SDK subprocess needed. Exercises:
|
||||||
|
|
||||||
|
1. TranscriptBuilder loads a "downloaded" transcript
|
||||||
|
2. User query appended, assistant response streamed
|
||||||
|
3. PreCompact hook fires → CompactionTracker.on_compact()
|
||||||
|
4. Next message → emit_start_if_ready() yields spinner events
|
||||||
|
5. Message after that → emit_end_if_ready() returns CompactionResult
|
||||||
|
6. read_compacted_entries() reads the CLI session file
|
||||||
|
7. TranscriptBuilder.replace_entries() syncs state
|
||||||
|
8. More messages appended post-compaction
|
||||||
|
9. to_jsonl() exports full state for upload
|
||||||
|
10. Fresh builder loads the export — roundtrip verified
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.copilot.response_model import (
|
||||||
|
StreamFinishStep,
|
||||||
|
StreamStartStep,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
from backend.copilot.sdk.compaction import CompactionTracker
|
||||||
|
from backend.copilot.sdk.transcript import (
|
||||||
|
read_compacted_entries,
|
||||||
|
strip_progress_entries,
|
||||||
|
)
|
||||||
|
from backend.copilot.sdk.transcript_builder import TranscriptBuilder
|
||||||
|
from backend.util import json
|
||||||
|
|
||||||
|
|
||||||
|
def _make_jsonl(*entries: dict) -> str:
|
||||||
|
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
"""Run an async coroutine synchronously."""
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures: realistic CLI session file content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Pre-compaction conversation
|
||||||
|
USER_1 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "What files are in this project?"},
|
||||||
|
}
|
||||||
|
ASST_1_THINKING = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-think",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_aaa",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "thinking", "thinking": "Let me look at the files..."}],
|
||||||
|
"stop_reason": None,
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ASST_1_TOOL = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-tool",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_aaa",
|
||||||
|
"type": "message",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "tu1",
|
||||||
|
"name": "Bash",
|
||||||
|
"input": {"command": "ls"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stop_reason": "tool_use",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
TOOL_RESULT_1 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "tr1",
|
||||||
|
"parentUuid": "a1-tool",
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "tu1",
|
||||||
|
"content": "file1.py\nfile2.py",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ASST_1_TEXT = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1-text",
|
||||||
|
"parentUuid": "tr1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_bbb",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "I found file1.py and file2.py."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# Progress entries (should be stripped during upload)
|
||||||
|
PROGRESS_1 = {
|
||||||
|
"type": "progress",
|
||||||
|
"uuid": "prog1",
|
||||||
|
"parentUuid": "a1-tool",
|
||||||
|
"data": {"type": "bash_progress", "stdout": "running ls..."},
|
||||||
|
}
|
||||||
|
# Second user message
|
||||||
|
USER_2 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u2",
|
||||||
|
"parentUuid": "a1-text",
|
||||||
|
"message": {"role": "user", "content": "Show me file1.py"},
|
||||||
|
}
|
||||||
|
ASST_2 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "u2",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_ccc",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is file1.py content..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Compaction summary (written by CLI after context compaction) ---
|
||||||
|
COMPACT_SUMMARY = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Summary: User asked about project files. Found file1.py and file2.py. "
|
||||||
|
"User then asked to see file1.py."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Post-compaction assistant response
|
||||||
|
POST_COMPACT_ASST = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a3",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_ddd",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is the content of file1.py..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Post-compaction user follow-up
|
||||||
|
USER_3 = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u3",
|
||||||
|
"parentUuid": "a3",
|
||||||
|
"message": {"role": "user", "content": "Now show file2.py"},
|
||||||
|
}
|
||||||
|
ASST_3 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a4",
|
||||||
|
"parentUuid": "u3",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"id": "msg_sdk_eee",
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "text", "text": "Here is file2.py..."}],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# E2E test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionE2E:
|
||||||
|
def _write_session_file(self, session_dir, entries):
|
||||||
|
"""Write a CLI session JSONL file."""
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
path.write_text(_make_jsonl(*entries))
|
||||||
|
return path
|
||||||
|
|
||||||
|
def test_full_compaction_lifecycle(self, tmp_path, monkeypatch):
|
||||||
|
"""Simulate the complete service.py compaction flow.
|
||||||
|
|
||||||
|
Timeline:
|
||||||
|
1. Previous turn uploaded transcript with [USER_1, ASST_1, USER_2, ASST_2]
|
||||||
|
2. Current turn: download → load_previous
|
||||||
|
3. User sends "Now show file2.py" → append_user
|
||||||
|
4. SDK starts streaming response
|
||||||
|
5. Mid-stream: PreCompact hook fires (context too large)
|
||||||
|
6. CLI writes compaction summary to session file
|
||||||
|
7. Next SDK message → emit_start (spinner)
|
||||||
|
8. Following message → emit_end (CompactionResult)
|
||||||
|
9. read_compacted_entries reads the session file
|
||||||
|
10. replace_entries syncs TranscriptBuilder
|
||||||
|
11. More assistant messages appended
|
||||||
|
12. Export → upload → next turn downloads it
|
||||||
|
"""
|
||||||
|
# --- Setup CLI projects directory ---
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
# --- Step 1-2: Load "downloaded" transcript from previous turn ---
|
||||||
|
previous_transcript = _make_jsonl(
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(previous_transcript)
|
||||||
|
assert builder.entry_count == 7
|
||||||
|
|
||||||
|
# --- Step 3: User sends new query ---
|
||||||
|
builder.append_user("Now show file2.py")
|
||||||
|
assert builder.entry_count == 8
|
||||||
|
|
||||||
|
# --- Step 4: SDK starts streaming ---
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "thinking", "thinking": "Let me read file2.py..."}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
)
|
||||||
|
assert builder.entry_count == 9
|
||||||
|
|
||||||
|
# --- Step 5-6: PreCompact fires, CLI writes session file ---
|
||||||
|
session_file = self._write_session_file(
|
||||||
|
session_dir,
|
||||||
|
[
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
PROGRESS_1,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
COMPACT_SUMMARY,
|
||||||
|
POST_COMPACT_ASST,
|
||||||
|
USER_3,
|
||||||
|
ASST_3,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Step 7: CompactionTracker receives PreCompact hook ---
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = ChatSession.new(user_id="test-user")
|
||||||
|
tracker.on_compact(str(session_file))
|
||||||
|
|
||||||
|
# --- Step 8: Next SDK message arrives → emit_start ---
|
||||||
|
start_events = tracker.emit_start_if_ready()
|
||||||
|
assert len(start_events) == 3
|
||||||
|
assert isinstance(start_events[0], StreamStartStep)
|
||||||
|
assert isinstance(start_events[1], StreamToolInputStart)
|
||||||
|
assert isinstance(start_events[2], StreamToolInputAvailable)
|
||||||
|
|
||||||
|
# Verify tool_call_id is set
|
||||||
|
tool_call_id = start_events[1].toolCallId
|
||||||
|
assert tool_call_id.startswith("compaction-")
|
||||||
|
|
||||||
|
# --- Step 9: Following message → emit_end ---
|
||||||
|
result = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert result.just_ended is True
|
||||||
|
assert result.transcript_path == str(session_file)
|
||||||
|
assert len(result.events) == 2
|
||||||
|
assert isinstance(result.events[0], StreamToolOutputAvailable)
|
||||||
|
assert isinstance(result.events[1], StreamFinishStep)
|
||||||
|
# Verify same tool_call_id
|
||||||
|
assert result.events[0].toolCallId == tool_call_id
|
||||||
|
|
||||||
|
# Session should have compaction messages persisted
|
||||||
|
assert len(session.messages) == 2
|
||||||
|
assert session.messages[0].role == "assistant"
|
||||||
|
assert session.messages[1].role == "tool"
|
||||||
|
|
||||||
|
# --- Step 10: read_compacted_entries + replace_entries ---
|
||||||
|
compacted = read_compacted_entries(str(session_file))
|
||||||
|
assert compacted is not None
|
||||||
|
# Should have: COMPACT_SUMMARY + POST_COMPACT_ASST + USER_3 + ASST_3
|
||||||
|
assert len(compacted) == 4
|
||||||
|
assert compacted[0]["uuid"] == "cs1"
|
||||||
|
assert compacted[0]["isCompactSummary"] is True
|
||||||
|
|
||||||
|
# Replace builder state with compacted entries
|
||||||
|
old_count = builder.entry_count
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 4 # Only compacted entries
|
||||||
|
assert builder.entry_count < old_count # Compaction reduced entries
|
||||||
|
|
||||||
|
# --- Step 11: More assistant messages after compaction ---
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "text", "text": "Here is file2.py:\n\ndef hello():\n pass"}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
stop_reason="end_turn",
|
||||||
|
)
|
||||||
|
assert builder.entry_count == 5
|
||||||
|
|
||||||
|
# --- Step 12: Export for upload ---
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
assert output # Not empty
|
||||||
|
output_entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert len(output_entries) == 5
|
||||||
|
|
||||||
|
# Verify structure:
|
||||||
|
# [COMPACT_SUMMARY, POST_COMPACT_ASST, USER_3, ASST_3, new_assistant]
|
||||||
|
assert output_entries[0]["type"] == "summary"
|
||||||
|
assert output_entries[0].get("isCompactSummary") is True
|
||||||
|
assert output_entries[0]["uuid"] == "cs1"
|
||||||
|
assert output_entries[1]["uuid"] == "a3"
|
||||||
|
assert output_entries[2]["uuid"] == "u3"
|
||||||
|
assert output_entries[3]["uuid"] == "a4"
|
||||||
|
assert output_entries[4]["type"] == "assistant"
|
||||||
|
|
||||||
|
# Verify parent chain is intact
|
||||||
|
assert output_entries[1]["parentUuid"] == "cs1" # a3 → cs1
|
||||||
|
assert output_entries[2]["parentUuid"] == "a3" # u3 → a3
|
||||||
|
assert output_entries[3]["parentUuid"] == "u3" # a4 → u3
|
||||||
|
assert output_entries[4]["parentUuid"] == "a4" # new → a4
|
||||||
|
|
||||||
|
# --- Step 13: Roundtrip — next turn loads this export ---
|
||||||
|
builder2 = TranscriptBuilder()
|
||||||
|
builder2.load_previous(output)
|
||||||
|
assert builder2.entry_count == 5
|
||||||
|
|
||||||
|
# isCompactSummary survives roundtrip
|
||||||
|
output2 = builder2.to_jsonl()
|
||||||
|
first_entry = json.loads(output2.strip().split("\n")[0])
|
||||||
|
assert first_entry.get("isCompactSummary") is True
|
||||||
|
|
||||||
|
# Can append more messages
|
||||||
|
builder2.append_user("What about file3.py?")
|
||||||
|
assert builder2.entry_count == 6
|
||||||
|
final_output = builder2.to_jsonl()
|
||||||
|
last_entry = json.loads(final_output.strip().split("\n")[-1])
|
||||||
|
assert last_entry["type"] == "user"
|
||||||
|
# Parented to the last entry from previous turn
|
||||||
|
assert last_entry["parentUuid"] == output_entries[-1]["uuid"]
|
||||||
|
|
||||||
|
def test_double_compaction_within_session(self, tmp_path, monkeypatch):
|
||||||
|
"""Two compactions in the same session (across reset_for_query)."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
tracker = CompactionTracker()
|
||||||
|
session = ChatSession.new(user_id="test")
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
|
||||||
|
# --- First query with compaction ---
|
||||||
|
builder.append_user("first question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "first answer"}])
|
||||||
|
|
||||||
|
# Write session file for first compaction
|
||||||
|
first_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs-first",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "First compaction summary"},
|
||||||
|
}
|
||||||
|
first_post = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a-first",
|
||||||
|
"parentUuid": "cs-first",
|
||||||
|
"message": {"role": "assistant", "content": "first post-compact"},
|
||||||
|
}
|
||||||
|
file1 = session_dir / "session1.jsonl"
|
||||||
|
file1.write_text(_make_jsonl(first_summary, first_post))
|
||||||
|
|
||||||
|
tracker.on_compact(str(file1))
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
result1 = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert result1.just_ended is True
|
||||||
|
|
||||||
|
compacted1 = read_compacted_entries(str(file1))
|
||||||
|
assert compacted1 is not None
|
||||||
|
builder.replace_entries(compacted1)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# --- Reset for second query ---
|
||||||
|
tracker.reset_for_query()
|
||||||
|
|
||||||
|
# --- Second query with compaction ---
|
||||||
|
builder.append_user("second question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "second answer"}])
|
||||||
|
|
||||||
|
second_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs-second",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "Second compaction summary"},
|
||||||
|
}
|
||||||
|
second_post = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a-second",
|
||||||
|
"parentUuid": "cs-second",
|
||||||
|
"message": {"role": "assistant", "content": "second post-compact"},
|
||||||
|
}
|
||||||
|
file2 = session_dir / "session2.jsonl"
|
||||||
|
file2.write_text(_make_jsonl(second_summary, second_post))
|
||||||
|
|
||||||
|
tracker.on_compact(str(file2))
|
||||||
|
tracker.emit_start_if_ready()
|
||||||
|
result2 = _run(tracker.emit_end_if_ready(session))
|
||||||
|
assert result2.just_ended is True
|
||||||
|
|
||||||
|
compacted2 = read_compacted_entries(str(file2))
|
||||||
|
assert compacted2 is not None
|
||||||
|
builder.replace_entries(compacted2)
|
||||||
|
assert builder.entry_count == 2 # Only second compaction entries
|
||||||
|
|
||||||
|
# Export and verify
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[0]["uuid"] == "cs-second"
|
||||||
|
assert entries[0].get("isCompactSummary") is True
|
||||||
|
|
||||||
|
def test_strip_progress_then_load_then_compact_roundtrip(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Full pipeline: strip → load → compact → replace → export → reload.
|
||||||
|
|
||||||
|
This tests the exact sequence that happens across two turns:
|
||||||
|
Turn 1: SDK produces transcript with progress entries
|
||||||
|
Upload: strip_progress_entries removes progress, upload to cloud
|
||||||
|
Turn 2: Download → load_previous → compaction fires → replace → export
|
||||||
|
Turn 3: Download the Turn 2 export → load_previous (roundtrip)
|
||||||
|
"""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
# --- Turn 1: SDK produces raw transcript ---
|
||||||
|
raw_content = _make_jsonl(
|
||||||
|
USER_1,
|
||||||
|
ASST_1_THINKING,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
PROGRESS_1,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip progress for upload
|
||||||
|
stripped = strip_progress_entries(raw_content)
|
||||||
|
stripped_entries = [
|
||||||
|
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||||
|
]
|
||||||
|
# Progress should be gone
|
||||||
|
assert not any(e.get("type") == "progress" for e in stripped_entries)
|
||||||
|
assert len(stripped_entries) == 7 # 8 - 1 progress
|
||||||
|
|
||||||
|
# --- Turn 2: Download stripped, load, compaction happens ---
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(stripped)
|
||||||
|
assert builder.entry_count == 7
|
||||||
|
|
||||||
|
builder.append_user("Now show file2.py")
|
||||||
|
builder.append_assistant(
|
||||||
|
[{"type": "text", "text": "Reading file2.py..."}],
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CLI writes session file with compaction
|
||||||
|
session_file = self._write_session_file(
|
||||||
|
session_dir,
|
||||||
|
[
|
||||||
|
USER_1,
|
||||||
|
ASST_1_TOOL,
|
||||||
|
TOOL_RESULT_1,
|
||||||
|
ASST_1_TEXT,
|
||||||
|
USER_2,
|
||||||
|
ASST_2,
|
||||||
|
COMPACT_SUMMARY,
|
||||||
|
POST_COMPACT_ASST,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
compacted = read_compacted_entries(str(session_file))
|
||||||
|
assert compacted is not None
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
|
||||||
|
# Append post-compaction message
|
||||||
|
builder.append_user("Thanks!")
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
|
||||||
|
# --- Turn 3: Fresh load of Turn 2 export ---
|
||||||
|
builder3 = TranscriptBuilder()
|
||||||
|
builder3.load_previous(output)
|
||||||
|
# Should have: compact_summary + post_compact_asst + "Thanks!"
|
||||||
|
assert builder3.entry_count == 3
|
||||||
|
|
||||||
|
# Compact summary survived the full pipeline
|
||||||
|
first = json.loads(builder3.to_jsonl().strip().split("\n")[0])
|
||||||
|
assert first.get("isCompactSummary") is True
|
||||||
|
assert first["type"] == "summary"
|
||||||
715
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
715
autogpt_platform/backend/backend/copilot/sdk/file_ref.py
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
"""File reference protocol for tool call inputs.
|
||||||
|
|
||||||
|
Allows the LLM to pass a file reference instead of embedding large content
|
||||||
|
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
|
||||||
|
arguments before the tool is executed.
|
||||||
|
|
||||||
|
Protocol
|
||||||
|
--------
|
||||||
|
|
||||||
|
@@agptfile:<uri>[<start>-<end>]
|
||||||
|
|
||||||
|
``<uri>`` (required)
|
||||||
|
- ``workspace://<file_id>`` — workspace file by ID
|
||||||
|
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
|
||||||
|
- ``workspace:///<path>`` — workspace file by virtual path
|
||||||
|
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
|
||||||
|
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
|
||||||
|
- Any absolute path that resolves inside the E2B sandbox
|
||||||
|
(``/home/user/...``) when a sandbox is active
|
||||||
|
|
||||||
|
``[<start>-<end>]`` (optional)
|
||||||
|
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
|
||||||
|
Omit to read the entire file.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
@@agptfile:workspace://abc123
|
||||||
|
@@agptfile:workspace://abc123[10-50]
|
||||||
|
@@agptfile:workspace:///reports/q1.md
|
||||||
|
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||||
|
@@agptfile:/home/user/script.sh
|
||||||
|
"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.copilot.context import (
|
||||||
|
get_current_sandbox,
|
||||||
|
get_sdk_cwd,
|
||||||
|
get_workspace_manager,
|
||||||
|
is_allowed_local_path,
|
||||||
|
resolve_sandbox_path,
|
||||||
|
)
|
||||||
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.util.file import parse_workspace_uri
|
||||||
|
from backend.util.file_content_parser import (
|
||||||
|
BINARY_FORMATS,
|
||||||
|
MIME_TO_FORMAT,
|
||||||
|
PARSE_EXCEPTIONS,
|
||||||
|
infer_format_from_uri,
|
||||||
|
parse_file_content,
|
||||||
|
)
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
|
class FileRefExpansionError(Exception):
|
||||||
|
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
|
||||||
|
|
||||||
|
Separating this from inline substitution lets callers (e.g. the MCP tool
|
||||||
|
wrapper) block tool execution and surface a helpful error to the model
|
||||||
|
rather than passing an ``[file-ref error: …]`` string as actual input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FILE_REF_PREFIX = "@@agptfile:"
|
||||||
|
|
||||||
|
# Matches: @@agptfile:<uri>[start-end]?
|
||||||
|
# Group 1 – URI; must start with '/' (absolute path) or 'workspace://'
|
||||||
|
# Group 2 – start line (optional)
|
||||||
|
# Group 3 – end line (optional)
|
||||||
|
_FILE_REF_RE = re.compile(
|
||||||
|
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Maximum characters returned for a single file reference expansion.
|
||||||
|
_MAX_EXPAND_CHARS = 200_000
|
||||||
|
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||||
|
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||||
|
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||||
|
_MAX_BARE_REF_BYTES = 10_000_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileRef:
|
||||||
|
uri: str
|
||||||
|
start_line: int | None # 1-indexed, inclusive
|
||||||
|
end_line: int | None # 1-indexed, inclusive
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API (top-down: main functions first, helpers below)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def parse_file_ref(text: str) -> FileRef | None:
|
||||||
|
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||||
|
|
||||||
|
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
|
||||||
|
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
|
||||||
|
expand references embedded in larger strings.
|
||||||
|
"""
|
||||||
|
m = _FILE_REF_RE.fullmatch(text.strip())
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
start = int(m.group(2)) if m.group(2) else None
|
||||||
|
end = int(m.group(3)) if m.group(3) else None
|
||||||
|
if start is not None and start < 1:
|
||||||
|
return None
|
||||||
|
if end is not None and end < 1:
|
||||||
|
return None
|
||||||
|
if start is not None and end is not None and end < start:
|
||||||
|
return None
|
||||||
|
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_file_bytes(
|
||||||
|
uri: str,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
) -> bytes:
|
||||||
|
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
|
||||||
|
|
||||||
|
Raises :class:`ValueError` if the URI cannot be resolved.
|
||||||
|
"""
|
||||||
|
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
|
||||||
|
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
|
||||||
|
|
||||||
|
if plain.startswith("workspace://"):
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("workspace:// file references require authentication")
|
||||||
|
manager = await get_workspace_manager(user_id, session.session_id)
|
||||||
|
ws = parse_workspace_uri(plain)
|
||||||
|
try:
|
||||||
|
data = await (
|
||||||
|
manager.read_file(ws.file_ref)
|
||||||
|
if ws.is_path
|
||||||
|
else manager.read_file_by_id(ws.file_ref)
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise ValueError(f"File not found: {plain}")
|
||||||
|
except (PermissionError, OSError) as exc:
|
||||||
|
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||||
|
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||||
|
# AttributeError/TypeError: workspace manager returned an
|
||||||
|
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||||
|
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||||
|
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||||
|
# NOTE: Workspace API does not support pre-read size checks;
|
||||||
|
# the full file is loaded before the size guard below.
|
||||||
|
if len(data) > _MAX_BARE_REF_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||||
|
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||||
|
try:
|
||||||
|
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||||
|
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||||
|
with open(resolved, "rb") as fh:
|
||||||
|
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||||
|
if len(data) > _MAX_BARE_REF_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||||
|
f"limit {_MAX_BARE_REF_BYTES})"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise ValueError(f"File not found: {plain}")
|
||||||
|
except OSError as exc:
|
||||||
|
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||||
|
|
||||||
|
sandbox = get_current_sandbox()
|
||||||
|
if sandbox is not None:
|
||||||
|
try:
|
||||||
|
remote = resolve_sandbox_path(plain)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||||
|
) from exc
|
||||||
|
try:
|
||||||
|
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||||
|
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||||
|
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||||
|
# inherit from standard exceptions. Import lazily to avoid a
|
||||||
|
# hard dependency on e2b at module level.
|
||||||
|
try:
|
||||||
|
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||||
|
|
||||||
|
if isinstance(exc, SandboxException):
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to read from sandbox: {plain}: {exc}"
|
||||||
|
) from exc
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||||
|
# so they surface as real bugs rather than being silently masked.
|
||||||
|
raise
|
||||||
|
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||||
|
# the full file is loaded before the size guard below.
|
||||||
|
if len(data) > _MAX_BARE_REF_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_file_ref(
|
||||||
|
ref: FileRef,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
) -> str:
|
||||||
|
"""Resolve a :class:`FileRef` to its text content."""
|
||||||
|
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||||
|
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||||
|
|
||||||
|
|
||||||
|
async def expand_file_refs_in_string(
|
||||||
|
text: str,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
*,
|
||||||
|
raise_on_error: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
|
||||||
|
|
||||||
|
Non-reference text is passed through unchanged.
|
||||||
|
|
||||||
|
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
|
||||||
|
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
|
||||||
|
where partial expansion is acceptable.
|
||||||
|
|
||||||
|
If *raise_on_error* is ``True``, any resolution failure raises
|
||||||
|
:class:`FileRefExpansionError` immediately so the caller can block the
|
||||||
|
operation and surface a clean error to the model.
|
||||||
|
"""
|
||||||
|
if FILE_REF_PREFIX not in text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
result: list[str] = []
|
||||||
|
last_end = 0
|
||||||
|
total_chars = 0
|
||||||
|
for m in _FILE_REF_RE.finditer(text):
|
||||||
|
result.append(text[last_end : m.start()])
|
||||||
|
start = int(m.group(2)) if m.group(2) else None
|
||||||
|
end = int(m.group(3)) if m.group(3) else None
|
||||||
|
if (start is not None and start < 1) or (end is not None and end < 1):
|
||||||
|
msg = f"line numbers must be >= 1: {m.group(0)}"
|
||||||
|
if raise_on_error:
|
||||||
|
raise FileRefExpansionError(msg)
|
||||||
|
result.append(f"[file-ref error: {msg}]")
|
||||||
|
last_end = m.end()
|
||||||
|
continue
|
||||||
|
if start is not None and end is not None and end < start:
|
||||||
|
msg = f"end line must be >= start line: {m.group(0)}"
|
||||||
|
if raise_on_error:
|
||||||
|
raise FileRefExpansionError(msg)
|
||||||
|
result.append(f"[file-ref error: {msg}]")
|
||||||
|
last_end = m.end()
|
||||||
|
continue
|
||||||
|
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||||
|
try:
|
||||||
|
content = await resolve_file_ref(ref, user_id, session)
|
||||||
|
if len(content) > _MAX_EXPAND_CHARS:
|
||||||
|
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||||
|
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||||
|
# remaining == 0 means the budget was exactly exhausted by the
|
||||||
|
# previous ref. The elif below (len > remaining) won't catch
|
||||||
|
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||||
|
if remaining <= 0:
|
||||||
|
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||||
|
elif len(content) > remaining:
|
||||||
|
content = content[:remaining] + "\n... [total budget exhausted]"
|
||||||
|
total_chars += len(content)
|
||||||
|
result.append(content)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
|
||||||
|
if raise_on_error:
|
||||||
|
raise FileRefExpansionError(str(exc)) from exc
|
||||||
|
result.append(f"[file-ref error: {exc}]")
|
||||||
|
last_end = m.end()
|
||||||
|
|
||||||
|
result.append(text[last_end:])
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def expand_file_refs_in_args(
|
||||||
|
args: dict[str, Any],
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
*,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||||
|
|
||||||
|
String values are expanded in-place. Nested dicts and lists are
|
||||||
|
traversed. Non-string scalars are returned unchanged.
|
||||||
|
|
||||||
|
**Bare references** (the entire argument value is a single
|
||||||
|
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||||
|
parsed according to the file's extension or MIME type. See
|
||||||
|
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||||
|
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||||
|
|
||||||
|
When *input_schema* is provided and the target property has
|
||||||
|
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||||
|
is returned as a plain string so blocks receive the original text.
|
||||||
|
|
||||||
|
If the format is unrecognised or parsing fails, the content is returned as
|
||||||
|
a plain string (the fallback).
|
||||||
|
|
||||||
|
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||||
|
produce a plain string — structured parsing only applies to bare refs.
|
||||||
|
|
||||||
|
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||||
|
so the tool is *not* executed with an error string as its input. The
|
||||||
|
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||||
|
response that lets the model correct the reference before retrying.
|
||||||
|
"""
|
||||||
|
if not args:
|
||||||
|
return args
|
||||||
|
|
||||||
|
properties = (input_schema or {}).get("properties", {})
|
||||||
|
|
||||||
|
async def _expand(
|
||||||
|
value: Any,
|
||||||
|
*,
|
||||||
|
prop_schema: dict[str, Any] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Recursively expand a single argument value.
|
||||||
|
|
||||||
|
Strings are checked for ``@@agptfile:`` references and expanded
|
||||||
|
(bare refs get structured parsing; embedded refs get inline
|
||||||
|
substitution). Dicts and lists are traversed recursively,
|
||||||
|
threading the corresponding sub-schema from *prop_schema* so
|
||||||
|
that nested fields also receive correct type-aware expansion.
|
||||||
|
Non-string scalars pass through unchanged.
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
ref = parse_file_ref(value)
|
||||||
|
if ref is not None:
|
||||||
|
# MediaFileType fields: return the raw URI immediately —
|
||||||
|
# no file reading, no format inference, no content parsing.
|
||||||
|
if _is_media_file_field(prop_schema):
|
||||||
|
return ref.uri
|
||||||
|
|
||||||
|
fmt = infer_format_from_uri(ref.uri)
|
||||||
|
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||||
|
# When the MIME fragment is also missing, fall back to the
|
||||||
|
# workspace file manager's metadata for format detection.
|
||||||
|
if fmt is None and ref.uri.startswith("workspace://"):
|
||||||
|
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||||
|
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||||
|
|
||||||
|
# Not a bare ref — do normal inline expansion.
|
||||||
|
return await expand_file_refs_in_string(
|
||||||
|
value, user_id, session, raise_on_error=True
|
||||||
|
)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# When the schema says this is an object but doesn't define
|
||||||
|
# inner properties, skip expansion — the caller (e.g.
|
||||||
|
# RunBlockTool) will expand with the actual nested schema.
|
||||||
|
if (
|
||||||
|
prop_schema is not None
|
||||||
|
and prop_schema.get("type") == "object"
|
||||||
|
and "properties" not in prop_schema
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
nested_props = (prop_schema or {}).get("properties", {})
|
||||||
|
return {
|
||||||
|
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||||
|
for k, v in value.items()
|
||||||
|
}
|
||||||
|
if isinstance(value, list):
|
||||||
|
items_schema = (prop_schema or {}).get("items")
|
||||||
|
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Private helpers (used by the public functions above)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||||
|
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||||
|
|
||||||
|
When the requested range extends beyond the file, a note is appended
|
||||||
|
so the LLM knows it received the entire remaining content.
|
||||||
|
"""
|
||||||
|
if start is None and end is None:
|
||||||
|
return text
|
||||||
|
lines = text.splitlines(keepends=True)
|
||||||
|
total = len(lines)
|
||||||
|
s = (start - 1) if start is not None else 0
|
||||||
|
e = end if end is not None else total
|
||||||
|
selected = list(itertools.islice(lines, s, e))
|
||||||
|
result = "".join(selected)
|
||||||
|
if end is not None and end > total:
|
||||||
|
result += f"\n[Note: file has only {total} lines]\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _to_str(content: str | bytes) -> str:
|
||||||
|
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
return content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_content_size(content: str | bytes) -> None:
|
||||||
|
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||||
|
|
||||||
|
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||||
|
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||||
|
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||||
|
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||||
|
|
||||||
|
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||||
|
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||||
|
mean the byte size can be up to 4x the character count.
|
||||||
|
"""
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
size = len(content)
|
||||||
|
else:
|
||||||
|
char_len = len(content)
|
||||||
|
# Fast lower bound: UTF-8 byte count >= char count.
|
||||||
|
# If char count already exceeds the limit, reject immediately
|
||||||
|
# without allocating an encoded copy.
|
||||||
|
if char_len > _MAX_BARE_REF_BYTES:
|
||||||
|
size = char_len # real byte size is even larger
|
||||||
|
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||||
|
# If worst-case is still under the limit, skip encoding entirely.
|
||||||
|
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Edge case: char count is under limit but multibyte chars
|
||||||
|
# might push byte count over. Encode to get exact size.
|
||||||
|
size = len(content.encode("utf-8"))
|
||||||
|
if size > _MAX_BARE_REF_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large for structured parsing "
|
||||||
|
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _infer_format_from_workspace(
|
||||||
|
uri: str,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
) -> str | None:
|
||||||
|
"""Look up workspace file metadata to infer the format.
|
||||||
|
|
||||||
|
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||||
|
When the MIME fragment is also absent, we query the workspace file
|
||||||
|
manager for the file's stored MIME type and original filename.
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
ws = parse_workspace_uri(uri)
|
||||||
|
manager = await get_workspace_manager(user_id, session.session_id)
|
||||||
|
info = await (
|
||||||
|
manager.get_file_info(ws.file_ref)
|
||||||
|
if not ws.is_path
|
||||||
|
else manager.get_file_info_by_path(ws.file_ref)
|
||||||
|
)
|
||||||
|
if info is None:
|
||||||
|
return None
|
||||||
|
# Try MIME type first, then filename extension.
|
||||||
|
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||||
|
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||||
|
except (
|
||||||
|
ValueError,
|
||||||
|
FileNotFoundError,
|
||||||
|
OSError,
|
||||||
|
PermissionError,
|
||||||
|
AttributeError,
|
||||||
|
TypeError,
|
||||||
|
):
|
||||||
|
# Expected failures: bad URI, missing file, permission denied, or
|
||||||
|
# workspace manager returning unexpected types. Propagate anything
|
||||||
|
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||||
|
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||||
|
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||||
|
if prop_schema is None:
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
prop_schema.get("type") == "string"
|
||||||
|
and prop_schema.get("format") == MediaFileType.string_format
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _expand_bare_ref(
|
||||||
|
ref: FileRef,
|
||||||
|
fmt: str | None,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
prop_schema: dict[str, Any] | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||||
|
|
||||||
|
This is the structured-parsing path: the file is read, optionally parsed
|
||||||
|
according to *fmt*, and adapted to the target *prop_schema*.
|
||||||
|
|
||||||
|
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||||
|
|
||||||
|
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||||
|
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if fmt is not None and fmt in BINARY_FORMATS:
|
||||||
|
# Binary formats need raw bytes, not UTF-8 text.
|
||||||
|
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||||
|
# — ignore them and parse full bytes. Warn so the caller/model
|
||||||
|
# knows the range was silently dropped.
|
||||||
|
if ref.start_line is not None or ref.end_line is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||||
|
"binary formats are always parsed in full.",
|
||||||
|
ref.start_line,
|
||||||
|
ref.end_line,
|
||||||
|
fmt,
|
||||||
|
ref.uri,
|
||||||
|
)
|
||||||
|
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||||
|
else:
|
||||||
|
content = await resolve_file_ref(ref, user_id, session)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise FileRefExpansionError(str(exc)) from exc
|
||||||
|
|
||||||
|
# For known formats this rejects files >10 MB before parsing.
|
||||||
|
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||||
|
# but this check still guards the parsing path which has no char limit.
|
||||||
|
# _check_content_size raises ValueError, which we unify here just like
|
||||||
|
# resolution errors above.
|
||||||
|
try:
|
||||||
|
_check_content_size(content)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise FileRefExpansionError(str(exc)) from exc
|
||||||
|
|
||||||
|
# When the schema declares this parameter as "string",
|
||||||
|
# return raw file content — don't parse into a structured
|
||||||
|
# type that would need json.dumps() serialisation.
|
||||||
|
expect_string = (prop_schema or {}).get("type") == "string"
|
||||||
|
if expect_string:
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
raise FileRefExpansionError(
|
||||||
|
f"Cannot use {fmt} file as text input: "
|
||||||
|
f"binary formats (parquet, xlsx) must be passed "
|
||||||
|
f"to a block that accepts structured data (list/object), "
|
||||||
|
f"not a string-typed parameter."
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
|
||||||
|
if fmt is not None:
|
||||||
|
# Use strict mode for binary formats so we surface the
|
||||||
|
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||||
|
# file) instead of silently returning garbled bytes.
|
||||||
|
strict = fmt in BINARY_FORMATS
|
||||||
|
try:
|
||||||
|
parsed = parse_file_content(content, fmt, strict=strict)
|
||||||
|
except PARSE_EXCEPTIONS as exc:
|
||||||
|
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||||
|
# Normalize bytes fallback to str so tools never
|
||||||
|
# receive raw bytes when parsing fails.
|
||||||
|
if isinstance(parsed, bytes):
|
||||||
|
parsed = _to_str(parsed)
|
||||||
|
return _adapt_to_schema(parsed, prop_schema)
|
||||||
|
|
||||||
|
# Unknown format — return as plain string, but apply
|
||||||
|
# the same per-ref character limit used by inline refs
|
||||||
|
# to prevent injecting unexpectedly large content.
|
||||||
|
text = _to_str(content)
|
||||||
|
if len(text) > _MAX_EXPAND_CHARS:
|
||||||
|
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||||
|
"""Adapt a parsed file value to better fit the target schema type.
|
||||||
|
|
||||||
|
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||||
|
that doesn't match the block's expected type, this function converts it to
|
||||||
|
a more useful representation instead of relying on pydantic's generic
|
||||||
|
coercion (which can produce awkward results like flattened dicts → lists).
|
||||||
|
|
||||||
|
Returns *parsed* unchanged when no adaptation is needed.
|
||||||
|
"""
|
||||||
|
if prop_schema is None:
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
target_type = prop_schema.get("type")
|
||||||
|
|
||||||
|
# Dict → array: delegate to helper.
|
||||||
|
if isinstance(parsed, dict) and target_type == "array":
|
||||||
|
return _adapt_dict_to_array(parsed, prop_schema)
|
||||||
|
|
||||||
|
# List → object: delegate to helper (raises for non-tabular lists).
|
||||||
|
if isinstance(parsed, list) and target_type == "object":
|
||||||
|
return _adapt_list_to_object(parsed)
|
||||||
|
|
||||||
|
# Tabular list → Any (no type): convert to list of dicts.
|
||||||
|
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||||
|
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||||
|
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||||
|
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||||
|
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||||
|
return _tabular_to_list_of_dicts(parsed)
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||||
|
"""Adapt a parsed dict to an array-typed field.
|
||||||
|
|
||||||
|
Extracts list-valued entries when the target item type is ``array``,
|
||||||
|
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||||
|
or wraps in ``[parsed]`` as a fallback.
|
||||||
|
"""
|
||||||
|
items_type = (prop_schema.get("items") or {}).get("type")
|
||||||
|
if items_type == "array":
|
||||||
|
# Target is List[List[Any]] — extract list-typed values from the
|
||||||
|
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||||
|
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||||
|
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||||
|
if list_values:
|
||||||
|
return list_values
|
||||||
|
if items_type == "string":
|
||||||
|
# Target is List[str] — wrapping a dict would give [dict]
|
||||||
|
# which can't coerce to strings. Return unchanged and let
|
||||||
|
# pydantic surface a clear validation error.
|
||||||
|
return parsed
|
||||||
|
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||||
|
# instead of pydantic flattening keys/values into a flat list.
|
||||||
|
return [parsed]
|
||||||
|
|
||||||
|
|
||||||
|
def _adapt_list_to_object(parsed: list) -> Any:
|
||||||
|
"""Adapt a parsed list to an object-typed field.
|
||||||
|
|
||||||
|
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||||
|
"""
|
||||||
|
if _is_tabular(parsed):
|
||||||
|
return _tabular_to_column_dict(parsed)
|
||||||
|
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||||
|
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||||
|
# get a clear error rather than pydantic silently wrapping the list.
|
||||||
|
raise FileRefExpansionError(
|
||||||
|
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||||
|
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tabular(parsed: Any) -> bool:
|
||||||
|
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||||
|
|
||||||
|
Uses isinstance checks because this is a structural type guard on
|
||||||
|
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||||
|
help here — we need to verify exact list-of-lists shape.
|
||||||
|
"""
|
||||||
|
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||||
|
return False
|
||||||
|
header = parsed[0]
|
||||||
|
if not isinstance(header, list) or not header:
|
||||||
|
return False
|
||||||
|
if not all(isinstance(h, str) for h in header):
|
||||||
|
return False
|
||||||
|
return all(isinstance(row, list) for row in parsed[1:])
|
||||||
|
|
||||||
|
|
||||||
|
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||||
|
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||||
|
|
||||||
|
Ragged rows (fewer columns than the header) get None for missing values.
|
||||||
|
Extra values beyond the header length are silently dropped.
|
||||||
|
"""
|
||||||
|
header = parsed[0]
|
||||||
|
return [
|
||||||
|
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||||
|
for row in parsed[1:]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||||
|
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||||
|
|
||||||
|
Ragged rows (fewer columns than the header) get None for missing values,
|
||||||
|
ensuring all columns have equal length.
|
||||||
|
"""
|
||||||
|
header = parsed[0]
|
||||||
|
return {
|
||||||
|
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||||
|
for i, col in enumerate(header)
|
||||||
|
}
|
||||||
@@ -0,0 +1,521 @@
|
|||||||
|
"""Integration tests for @@agptfile: reference expansion in tool calls.
|
||||||
|
|
||||||
|
These tests verify the end-to-end behaviour of the file reference protocol:
|
||||||
|
- Parsing @@agptfile: tokens from tool arguments
|
||||||
|
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
|
||||||
|
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
|
||||||
|
- The extended Read tool handler (workspace:// pass-through via session context)
|
||||||
|
|
||||||
|
No real LLM or database is required; workspace reads are stubbed where needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.copilot.sdk.file_ref import (
|
||||||
|
FileRef,
|
||||||
|
expand_file_refs_in_args,
|
||||||
|
expand_file_refs_in_string,
|
||||||
|
read_file_bytes,
|
||||||
|
resolve_file_ref,
|
||||||
|
)
|
||||||
|
from backend.copilot.sdk.tool_adapter import _read_file_handler
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(session_id: str = "integ-sess") -> MagicMock:
|
||||||
|
s = MagicMock()
|
||||||
|
s.session_id = session_id
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Local-file resolution (sdk_cwd)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_file_ref_local_path():
|
||||||
|
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
# Write a test file inside sdk_cwd
|
||||||
|
test_file = os.path.join(sdk_cwd, "hello.txt")
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("line1\nline2\nline3\n")
|
||||||
|
|
||||||
|
session = _make_session()
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
ref = FileRef(uri=test_file, start_line=None, end_line=None)
|
||||||
|
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||||
|
|
||||||
|
assert content == "line1\nline2\nline3\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_file_ref_local_path_with_line_range():
|
||||||
|
"""resolve_file_ref respects line ranges for local files."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
test_file = os.path.join(sdk_cwd, "multi.txt")
|
||||||
|
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.writelines(lines)
|
||||||
|
|
||||||
|
session = _make_session()
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
ref = FileRef(uri=test_file, start_line=3, end_line=5)
|
||||||
|
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||||
|
|
||||||
|
assert content == "line3\nline4\nline5\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||||
|
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||||
|
"backend.copilot.context._current_sandbox"
|
||||||
|
) as mock_sandbox_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
mock_sandbox_var.get.return_value = None
|
||||||
|
|
||||||
|
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
|
||||||
|
with pytest.raises(ValueError, match="not allowed"):
|
||||||
|
await resolve_file_ref(ref, user_id="u1", session=_make_session())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# expand_file_refs_in_string — integration with real files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_expand_string_with_real_file():
|
||||||
|
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
test_file = os.path.join(sdk_cwd, "data.txt")
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("hello world\n")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_string(
|
||||||
|
f"Content: @@agptfile:{test_file}",
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Content: hello world\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_expand_string_missing_file_is_surfaced_inline():
|
||||||
|
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_string(
|
||||||
|
f"@@agptfile:{missing}",
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "[file-ref error:" in result
|
||||||
|
assert "not found" in result.lower() or "not allowed" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# expand_file_refs_in_args — dict traversal with real files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||||
|
"""Nested @@agptfile: references in args are fully expanded."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
file_a = os.path.join(sdk_cwd, "a.txt")
|
||||||
|
file_b = os.path.join(sdk_cwd, "b.txt")
|
||||||
|
with open(file_a, "w") as f:
|
||||||
|
f.write("AAA")
|
||||||
|
with open(file_b, "w") as f:
|
||||||
|
f.write("BBB")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{
|
||||||
|
"outer": {
|
||||||
|
"content_a": f"@@agptfile:{file_a}",
|
||||||
|
"content_b": f"start @@agptfile:{file_b} end",
|
||||||
|
},
|
||||||
|
"count": 42,
|
||||||
|
},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["outer"]["content_a"] == "AAA"
|
||||||
|
assert result["outer"]["content_b"] == "start BBB end"
|
||||||
|
assert result["count"] == 42
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# expand_file_refs_in_args — bare ref structured parsing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_json_returns_parsed_dict():
|
||||||
|
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
json_file = os.path.join(sdk_cwd, "data.json")
|
||||||
|
with open(json_file, "w") as f:
|
||||||
|
f.write('{"key": "value", "count": 42}')
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"data": f"@@agptfile:{json_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["data"] == {"key": "value", "count": 42}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_csv_returns_parsed_table():
|
||||||
|
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||||
|
with open(csv_file, "w") as f:
|
||||||
|
f.write("Name,Score\nAlice,90\nBob,85")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"input": f"@@agptfile:{csv_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["input"] == [
|
||||||
|
["Name", "Score"],
|
||||||
|
["Alice", "90"],
|
||||||
|
["Bob", "85"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_unknown_extension_returns_string():
|
||||||
|
"""Bare ref to a file with unknown extension returns plain string."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||||
|
with open(txt_file, "w") as f:
|
||||||
|
f.write("plain text content")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"data": f"@@agptfile:{txt_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["data"] == "plain text content"
|
||||||
|
assert isinstance(result["data"], str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||||
|
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||||
|
with open(json_file, "w") as f:
|
||||||
|
f.write("not valid json {{{")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"data": f"@@agptfile:{json_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["data"] == "not valid json {{{"
|
||||||
|
assert isinstance(result["data"], str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||||
|
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
json_file = os.path.join(sdk_cwd, "data.json")
|
||||||
|
with open(json_file, "w") as f:
|
||||||
|
f.write('{"key": "value"}')
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result["data"], str)
|
||||||
|
assert result["data"].startswith("prefix ")
|
||||||
|
assert result["data"].endswith(" suffix")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||||
|
"""Bare ref to a .yaml file returns parsed dict."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||||
|
with open(yaml_file, "w") as f:
|
||||||
|
f.write("name: test\ncount: 42\n")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"config": f"@@agptfile:{yaml_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["config"] == {"name": "test", "count": 42}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||||
|
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||||
|
|
||||||
|
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||||
|
parsed and the range is silently dropped with a log warning.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("pandas not installed")
|
||||||
|
try:
|
||||||
|
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("pyarrow not installed")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||||
|
buf = _io.BytesIO()
|
||||||
|
df.to_parquet(buf, index=False)
|
||||||
|
with open(parquet_file, "wb") as f:
|
||||||
|
f.write(buf.getvalue())
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
# Line range [1-2] should be silently ignored for binary formats.
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Full file is returned despite the line range.
|
||||||
|
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bare_ref_toml_returns_parsed_dict():
|
||||||
|
"""Bare ref to a .toml file returns parsed dict."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||||
|
with open(toml_file, "w") as f:
|
||||||
|
f.write('name = "test"\ncount = 42\n')
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
|
||||||
|
result = await expand_file_refs_in_args(
|
||||||
|
{"config": f"@@agptfile:{toml_file}"},
|
||||||
|
user_id="u1",
|
||||||
|
session=_make_session(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["config"] == {"name": "test", "count": 42}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _read_file_handler — extended to accept workspace:// and local paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_handler_local_file():
|
||||||
|
"""_read_file_handler reads a local file when it's within sdk_cwd."""
|
||||||
|
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||||
|
test_file = os.path.join(sdk_cwd, "read_test.txt")
|
||||||
|
lines = [f"L{i}\n" for i in range(1, 6)]
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.writelines(lines)
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||||
|
"backend.copilot.context._current_project_dir"
|
||||||
|
) as mock_proj_var, patch(
|
||||||
|
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||||
|
return_value=("user-1", _make_session()),
|
||||||
|
):
|
||||||
|
mock_cwd_var.get.return_value = sdk_cwd
|
||||||
|
mock_proj_var.get.return_value = ""
|
||||||
|
|
||||||
|
result = await _read_file_handler(
|
||||||
|
{"file_path": test_file, "offset": 0, "limit": 5}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not result["isError"]
|
||||||
|
text = result["content"][0]["text"]
|
||||||
|
assert "L1" in text
|
||||||
|
assert "L5" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_handler_workspace_uri():
|
||||||
|
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
|
||||||
|
mock_session = _make_session()
|
||||||
|
mock_manager = AsyncMock()
|
||||||
|
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||||
|
return_value=("user-1", mock_session),
|
||||||
|
), patch(
|
||||||
|
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||||
|
new=AsyncMock(return_value=mock_manager),
|
||||||
|
):
|
||||||
|
result = await _read_file_handler(
|
||||||
|
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not result["isError"], result["content"][0]["text"]
|
||||||
|
text = result["content"][0]["text"]
|
||||||
|
assert "workspace file content" in text
|
||||||
|
assert "line two" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_handler_workspace_uri_no_session():
|
||||||
|
"""_read_file_handler returns error when workspace:// is used without session."""
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||||
|
return_value=(None, None),
|
||||||
|
):
|
||||||
|
result = await _read_file_handler({"file_path": "workspace://some-id"})
|
||||||
|
|
||||||
|
assert result["isError"]
|
||||||
|
assert "session" in result["content"][0]["text"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_handler_access_denied():
|
||||||
|
"""_read_file_handler rejects paths outside allowed locations."""
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||||
|
"backend.copilot.context._current_sandbox"
|
||||||
|
) as mock_sandbox, patch(
|
||||||
|
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||||
|
return_value=("user-1", _make_session()),
|
||||||
|
):
|
||||||
|
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||||
|
mock_sandbox.get.return_value = None
|
||||||
|
|
||||||
|
result = await _read_file_handler({"file_path": "/etc/passwd"})
|
||||||
|
|
||||||
|
assert result["isError"]
|
||||||
|
assert "not allowed" in result["content"][0]["text"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_bytes_workspace_virtual_path():
|
||||||
|
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
|
||||||
|
session = _make_session()
|
||||||
|
mock_manager = AsyncMock()
|
||||||
|
mock_manager.read_file.return_value = b"virtual path content"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||||
|
new=AsyncMock(return_value=mock_manager),
|
||||||
|
):
|
||||||
|
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||||
|
|
||||||
|
assert result == b"virtual path content"
|
||||||
|
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_bytes_e2b_sandbox_branch():
|
||||||
|
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
|
||||||
|
session = _make_session()
|
||||||
|
mock_sandbox = AsyncMock()
|
||||||
|
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||||
|
"backend.copilot.context._current_sandbox"
|
||||||
|
) as mock_sandbox_var, patch(
|
||||||
|
"backend.copilot.context._current_project_dir"
|
||||||
|
) as mock_proj:
|
||||||
|
mock_cwd.get.return_value = ""
|
||||||
|
mock_sandbox_var.get.return_value = mock_sandbox
|
||||||
|
mock_proj.get.return_value = ""
|
||||||
|
|
||||||
|
result = await read_file_bytes("/home/user/script.sh", None, session)
|
||||||
|
|
||||||
|
assert result == b"sandbox content"
|
||||||
|
mock_sandbox.files.read.assert_awaited_once_with(
|
||||||
|
"/home/user/script.sh", format="bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||||
|
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
|
||||||
|
session = _make_session()
|
||||||
|
mock_sandbox = AsyncMock()
|
||||||
|
|
||||||
|
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||||
|
"backend.copilot.context._current_sandbox"
|
||||||
|
) as mock_sandbox_var, patch(
|
||||||
|
"backend.copilot.context._current_project_dir"
|
||||||
|
) as mock_proj:
|
||||||
|
mock_cwd.get.return_value = ""
|
||||||
|
mock_sandbox_var.get.return_value = mock_sandbox
|
||||||
|
mock_proj.get.return_value = ""
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not allowed"):
|
||||||
|
await read_file_bytes("/etc/passwd", None, session)
|
||||||
1979
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
1979
autogpt_platform/backend/backend/copilot/sdk/file_ref_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,59 @@
|
|||||||
|
## MCP Tool Guide
|
||||||
|
|
||||||
|
### Workflow
|
||||||
|
|
||||||
|
`run_mcp_tool` follows a two-step pattern:
|
||||||
|
|
||||||
|
1. **Discover** — call with only `server_url` to list available tools on the server.
|
||||||
|
2. **Execute** — call again with `server_url`, `tool_name`, and `tool_arguments` to run a tool.
|
||||||
|
|
||||||
|
### Known hosted MCP servers
|
||||||
|
|
||||||
|
Use these URLs directly without asking the user:
|
||||||
|
|
||||||
|
| Service | URL |
|
||||||
|
|---|---|
|
||||||
|
| Notion | `https://mcp.notion.com/mcp` |
|
||||||
|
| Linear | `https://mcp.linear.app/mcp` |
|
||||||
|
| Stripe | `https://mcp.stripe.com` |
|
||||||
|
| Intercom | `https://mcp.intercom.com/mcp` |
|
||||||
|
| Cloudflare | `https://mcp.cloudflare.com/mcp` |
|
||||||
|
| Atlassian / Jira | `https://mcp.atlassian.com/mcp` |
|
||||||
|
|
||||||
|
For other services, search the MCP registry API:
|
||||||
|
```http
|
||||||
|
GET https://registry.modelcontextprotocol.io/v0/servers?q=<search_term>
|
||||||
|
```
|
||||||
|
Each result includes a `remotes` array with the exact server URL to use.
|
||||||
|
|
||||||
|
### Important: Check blocks first
|
||||||
|
|
||||||
|
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
|
||||||
|
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
|
||||||
|
Google Calendar, Gmail, etc.) that work without MCP setup.
|
||||||
|
|
||||||
|
Only use `run_mcp_tool` when:
|
||||||
|
- The service is in the known hosted MCP servers list above, OR
|
||||||
|
- You searched `find_block` first and found no matching blocks
|
||||||
|
|
||||||
|
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
|
||||||
|
or from the `remotes[].url` field in MCP registry search results.
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
If the server requires credentials, a `SetupRequirementsResponse` is returned with an OAuth
|
||||||
|
login prompt. Once the user completes the flow and confirms, retry the same call immediately.
|
||||||
|
|
||||||
|
### Communication style
|
||||||
|
|
||||||
|
Avoid technical jargon like "MCP server", "OAuth", or "credentials" when talking to the user.
|
||||||
|
Use plain, friendly language instead:
|
||||||
|
|
||||||
|
| Instead of… | Say… |
|
||||||
|
|---|---|
|
||||||
|
| "Let me connect to Sentry's MCP server and discover what tools are available." | "I can connect to Sentry and help identify important issues." |
|
||||||
|
| "Let me connect to Sentry's MCP server now." | "Next, I'll connect to Sentry." |
|
||||||
|
| "The MCP server at mcp.sentry.dev requires authentication. Please connect your credentials to continue." | "To continue, sign in to Sentry and approve access." |
|
||||||
|
| "Sentry's MCP server needs OAuth authentication. You should see a prompt to connect your Sentry account…" | "You should see a prompt to sign in to Sentry. Once connected, I can help surface critical issues right away." |
|
||||||
|
|
||||||
|
Use **"connect to [Service]"** or **"sign in to [Service]"** — never "MCP server", "OAuth", or "credentials".
|
||||||
@@ -536,10 +536,12 @@ async def test_wait_for_stash_signaled():
|
|||||||
result = await wait_for_stash(timeout=1.0)
|
result = await wait_for_stash(timeout=1.0)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
pto = _pto.get()
|
||||||
|
assert pto is not None
|
||||||
|
assert pto.get("WebSearch") == ["result data"]
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
_pto.set({}) # type: ignore[arg-type]
|
_pto.set({})
|
||||||
_stash_event.set(None)
|
_stash_event.set(None)
|
||||||
|
|
||||||
|
|
||||||
@@ -554,7 +556,7 @@ async def test_wait_for_stash_timeout():
|
|||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
_pto.set({}) # type: ignore[arg-type]
|
_pto.set({})
|
||||||
_stash_event.set(None)
|
_stash_event.set(None)
|
||||||
|
|
||||||
|
|
||||||
@@ -573,10 +575,12 @@ async def test_wait_for_stash_already_stashed():
|
|||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# But the stash itself is populated
|
# But the stash itself is populated
|
||||||
assert _pto.get({}).get("Read") == ["file contents"]
|
pto = _pto.get()
|
||||||
|
assert pto is not None
|
||||||
|
assert pto.get("Read") == ["file contents"]
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
_pto.set({}) # type: ignore[arg-type]
|
_pto.set({})
|
||||||
_stash_event.set(None)
|
_stash_event.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,13 @@ import re
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from backend.copilot.context import is_allowed_local_path
|
||||||
|
|
||||||
from .tool_adapter import (
|
from .tool_adapter import (
|
||||||
BLOCKED_TOOLS,
|
BLOCKED_TOOLS,
|
||||||
DANGEROUS_PATTERNS,
|
DANGEROUS_PATTERNS,
|
||||||
MCP_TOOL_PREFIX,
|
MCP_TOOL_PREFIX,
|
||||||
WORKSPACE_SCOPED_TOOLS,
|
WORKSPACE_SCOPED_TOOLS,
|
||||||
is_allowed_local_path,
|
|
||||||
stash_pending_tool_output,
|
stash_pending_tool_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,7 +127,7 @@ def create_security_hooks(
|
|||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
sdk_cwd: str | None = None,
|
sdk_cwd: str | None = None,
|
||||||
max_subtasks: int = 3,
|
max_subtasks: int = 3,
|
||||||
on_compact: Callable[[], None] | None = None,
|
on_compact: Callable[[str], None] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create the security hooks configuration for Claude Agent SDK.
|
"""Create the security hooks configuration for Claude Agent SDK.
|
||||||
|
|
||||||
@@ -141,6 +142,7 @@ def create_security_hooks(
|
|||||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||||
on_compact: Callback invoked when SDK starts compacting context.
|
on_compact: Callback invoked when SDK starts compacting context.
|
||||||
|
Receives the transcript_path from the hook input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Hooks configuration dict for ClaudeAgentOptions
|
Hooks configuration dict for ClaudeAgentOptions
|
||||||
@@ -300,11 +302,21 @@ def create_security_hooks(
|
|||||||
"""
|
"""
|
||||||
_ = context, tool_use_id
|
_ = context, tool_use_id
|
||||||
trigger = input_data.get("trigger", "auto")
|
trigger = input_data.get("trigger", "auto")
|
||||||
|
# Sanitize untrusted input before logging to prevent log injection
|
||||||
|
transcript_path = (
|
||||||
|
str(input_data.get("transcript_path", ""))
|
||||||
|
.replace("\n", "")
|
||||||
|
.replace("\r", "")
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||||
|
"transcript_path=%s",
|
||||||
|
trigger,
|
||||||
|
user_id,
|
||||||
|
transcript_path,
|
||||||
)
|
)
|
||||||
if on_compact is not None:
|
if on_compact is not None:
|
||||||
on_compact()
|
on_compact(transcript_path)
|
||||||
return cast(SyncHookJSONOutput, {})
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
hooks: dict[str, Any] = {
|
hooks: dict[str, Any] = {
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.copilot.context import _current_project_dir
|
||||||
|
|
||||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||||
from .service import _is_tool_error_or_denial
|
|
||||||
|
|
||||||
SDK_CWD = "/tmp/copilot-abc123"
|
SDK_CWD = "/tmp/copilot-abc123"
|
||||||
|
|
||||||
@@ -120,8 +121,6 @@ def test_read_no_cwd_denies_absolute():
|
|||||||
|
|
||||||
|
|
||||||
def test_read_tool_results_allowed():
|
def test_read_tool_results_allowed():
|
||||||
from .tool_adapter import _current_project_dir
|
|
||||||
|
|
||||||
home = os.path.expanduser("~")
|
home = os.path.expanduser("~")
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||||
@@ -133,16 +132,14 @@ def test_read_tool_results_allowed():
|
|||||||
_current_project_dir.reset(token)
|
_current_project_dir.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def test_read_claude_projects_session_dir_allowed():
|
def test_read_claude_projects_settings_json_denied():
|
||||||
"""Files within the current session's project dir are allowed."""
|
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
|
||||||
from .tool_adapter import _current_project_dir
|
|
||||||
|
|
||||||
home = os.path.expanduser("~")
|
home = os.path.expanduser("~")
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||||
try:
|
try:
|
||||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||||
assert not _is_denied(result)
|
assert _is_denied(result)
|
||||||
finally:
|
finally:
|
||||||
_current_project_dir.reset(token)
|
_current_project_dir.reset(token)
|
||||||
|
|
||||||
@@ -357,76 +354,3 @@ async def test_task_slot_released_on_failure(_hooks):
|
|||||||
context={},
|
context={},
|
||||||
)
|
)
|
||||||
assert not _is_denied(result)
|
assert not _is_denied(result)
|
||||||
|
|
||||||
|
|
||||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestIsToolErrorOrDenial:
|
|
||||||
def test_none_content(self):
|
|
||||||
assert _is_tool_error_or_denial(None) is False
|
|
||||||
|
|
||||||
def test_empty_content(self):
|
|
||||||
assert _is_tool_error_or_denial("") is False
|
|
||||||
|
|
||||||
def test_benign_output(self):
|
|
||||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
|
||||||
|
|
||||||
def test_security_marker(self):
|
|
||||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
|
||||||
|
|
||||||
def test_cannot_be_bypassed(self):
|
|
||||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
|
||||||
|
|
||||||
def test_not_allowed(self):
|
|
||||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
|
||||||
|
|
||||||
def test_background_task_denial(self):
|
|
||||||
assert (
|
|
||||||
_is_tool_error_or_denial(
|
|
||||||
"Background task execution is not supported. "
|
|
||||||
"Run tasks in the foreground instead."
|
|
||||||
)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_subtask_limit_denial(self):
|
|
||||||
assert (
|
|
||||||
_is_tool_error_or_denial(
|
|
||||||
"Maximum 2 concurrent sub-tasks. "
|
|
||||||
"Wait for running sub-tasks to finish, "
|
|
||||||
"or continue in the main conversation."
|
|
||||||
)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_denied_marker(self):
|
|
||||||
assert (
|
|
||||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_blocked_marker(self):
|
|
||||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
|
||||||
|
|
||||||
def test_failed_marker(self):
|
|
||||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
|
||||||
|
|
||||||
def test_mcp_iserror(self):
|
|
||||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
|
||||||
|
|
||||||
def test_benign_error_in_value(self):
|
|
||||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
|
||||||
assert _is_tool_error_or_denial("0 errors found") is False
|
|
||||||
|
|
||||||
def test_benign_permission_field(self):
|
|
||||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
|
||||||
assert (
|
|
||||||
_is_tool_error_or_denial(
|
|
||||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_benign_not_found_in_listing(self):
|
|
||||||
"""File listing containing 'not found' in filenames should not trigger."""
|
|
||||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from langfuse import propagate_attributes
|
|||||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.copilot.context import get_workspace_manager
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.executor.cluster_lock import AsyncClusterLock
|
from backend.executor.cluster_lock import AsyncClusterLock
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
@@ -60,9 +61,8 @@ from ..service import (
|
|||||||
_generate_session_title,
|
_generate_session_title,
|
||||||
_is_langfuse_configured,
|
_is_langfuse_configured,
|
||||||
)
|
)
|
||||||
from ..tools.e2b_sandbox import get_or_create_sandbox
|
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||||
from ..tools.workspace_files import get_manager
|
|
||||||
from ..tracking import track_user_message
|
from ..tracking import track_user_message
|
||||||
from .compaction import CompactionTracker, filter_compaction_messages
|
from .compaction import CompactionTracker, filter_compaction_messages
|
||||||
from .response_adapter import SDKResponseAdapter
|
from .response_adapter import SDKResponseAdapter
|
||||||
@@ -77,6 +77,7 @@ from .tool_adapter import (
|
|||||||
from .transcript import (
|
from .transcript import (
|
||||||
cleanup_cli_project_dir,
|
cleanup_cli_project_dir,
|
||||||
download_transcript,
|
download_transcript,
|
||||||
|
read_compacted_entries,
|
||||||
upload_transcript,
|
upload_transcript,
|
||||||
validate_transcript,
|
validate_transcript,
|
||||||
write_transcript_to_tempfile,
|
write_transcript_to_tempfile,
|
||||||
@@ -456,31 +457,6 @@ def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
|||||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||||
|
|
||||||
|
|
||||||
def _is_tool_error_or_denial(content: str | None) -> bool:
|
|
||||||
"""Check if a tool message content indicates an error or denial.
|
|
||||||
|
|
||||||
Currently unused — ``_format_conversation_context`` includes all tool
|
|
||||||
results. Kept as a utility for future selective filtering.
|
|
||||||
"""
|
|
||||||
if not content:
|
|
||||||
return False
|
|
||||||
lower = content.lower()
|
|
||||||
return any(
|
|
||||||
marker in lower
|
|
||||||
for marker in (
|
|
||||||
"[security]",
|
|
||||||
"cannot be bypassed",
|
|
||||||
"not allowed",
|
|
||||||
"not supported", # background-task denial
|
|
||||||
"maximum", # subtask-limit denial
|
|
||||||
"denied",
|
|
||||||
"blocked",
|
|
||||||
"failed to", # internal tool execution failures
|
|
||||||
'"iserror": true', # MCP protocol error flag
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _build_query_message(
|
async def _build_query_message(
|
||||||
current_message: str,
|
current_message: str,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
@@ -589,7 +565,7 @@ async def _prepare_file_attachments(
|
|||||||
return empty
|
return empty
|
||||||
|
|
||||||
try:
|
try:
|
||||||
manager = await get_manager(user_id, session_id)
|
manager = await get_workspace_manager(user_id, session_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to create workspace manager for file attachments",
|
"Failed to create workspace manager for file attachments",
|
||||||
@@ -784,28 +760,29 @@ async def stream_chat_completion_sdk(
|
|||||||
|
|
||||||
async def _setup_e2b():
|
async def _setup_e2b():
|
||||||
"""Set up E2B sandbox if configured, return sandbox or None."""
|
"""Set up E2B sandbox if configured, return sandbox or None."""
|
||||||
if config.use_e2b_sandbox and not config.e2b_api_key:
|
if not (e2b_api_key := config.active_e2b_api_key):
|
||||||
logger.warning(
|
if config.use_e2b_sandbox:
|
||||||
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
logger.warning(
|
||||||
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
"[E2B] [%s] E2B sandbox enabled but no API key configured "
|
||||||
session_id[:12],
|
"(CHAT_E2B_API_KEY / E2B_API_KEY) — falling back to bubblewrap",
|
||||||
)
|
|
||||||
return None
|
|
||||||
if config.use_e2b_sandbox and config.e2b_api_key:
|
|
||||||
try:
|
|
||||||
return await get_or_create_sandbox(
|
|
||||||
session_id,
|
|
||||||
api_key=config.e2b_api_key,
|
|
||||||
template=config.e2b_sandbox_template,
|
|
||||||
timeout=config.e2b_sandbox_timeout,
|
|
||||||
)
|
|
||||||
except Exception as e2b_err:
|
|
||||||
logger.error(
|
|
||||||
"[E2B] [%s] Setup failed: %s",
|
|
||||||
session_id[:12],
|
session_id[:12],
|
||||||
e2b_err,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return await get_or_create_sandbox(
|
||||||
|
session_id,
|
||||||
|
api_key=e2b_api_key,
|
||||||
|
template=config.e2b_sandbox_template,
|
||||||
|
timeout=config.e2b_sandbox_timeout,
|
||||||
|
on_timeout=config.e2b_sandbox_on_timeout,
|
||||||
|
)
|
||||||
|
except Exception as e2b_err:
|
||||||
|
logger.error(
|
||||||
|
"[E2B] [%s] Setup failed: %s",
|
||||||
|
session_id[:12],
|
||||||
|
e2b_err,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _fetch_transcript():
|
async def _fetch_transcript():
|
||||||
@@ -837,7 +814,6 @@ async def stream_chat_completion_sdk(
|
|||||||
system_prompt = base_system_prompt + get_sdk_supplement(
|
system_prompt = base_system_prompt + get_sdk_supplement(
|
||||||
use_e2b=use_e2b, cwd=sdk_cwd
|
use_e2b=use_e2b, cwd=sdk_cwd
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process transcript download result
|
# Process transcript download result
|
||||||
transcript_msg_count = 0
|
transcript_msg_count = 0
|
||||||
if dl:
|
if dl:
|
||||||
@@ -902,6 +878,11 @@ async def stream_chat_completion_sdk(
|
|||||||
|
|
||||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||||
|
|
||||||
|
def _on_stderr(line: str) -> None:
|
||||||
|
sid = session_id[:12] if session_id else "?"
|
||||||
|
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
|
||||||
|
|
||||||
sdk_options_kwargs: dict[str, Any] = {
|
sdk_options_kwargs: dict[str, Any] = {
|
||||||
"system_prompt": system_prompt,
|
"system_prompt": system_prompt,
|
||||||
"mcp_servers": {"copilot": mcp_server},
|
"mcp_servers": {"copilot": mcp_server},
|
||||||
@@ -910,6 +891,7 @@ async def stream_chat_completion_sdk(
|
|||||||
"hooks": security_hooks,
|
"hooks": security_hooks,
|
||||||
"cwd": sdk_cwd,
|
"cwd": sdk_cwd,
|
||||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||||
|
"stderr": _on_stderr,
|
||||||
}
|
}
|
||||||
if sdk_model:
|
if sdk_model:
|
||||||
sdk_options_kwargs["model"] = sdk_model
|
sdk_options_kwargs["model"] = sdk_model
|
||||||
@@ -1064,6 +1046,7 @@ async def stream_chat_completion_sdk(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
ended_with_stream_error = True
|
ended_with_stream_error = True
|
||||||
|
|
||||||
yield StreamError(
|
yield StreamError(
|
||||||
errorText=f"SDK stream error: {stream_err}",
|
errorText=f"SDK stream error: {stream_err}",
|
||||||
code="sdk_stream_error",
|
code="sdk_stream_error",
|
||||||
@@ -1082,6 +1065,19 @@ async def stream_chat_completion_sdk(
|
|||||||
len(adapter.resolved_tool_calls),
|
len(adapter.resolved_tool_calls),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Log AssistantMessage API errors (e.g. invalid_request)
|
||||||
|
# so we can debug Anthropic API 400s surfaced by the CLI.
|
||||||
|
sdk_error = getattr(sdk_msg, "error", None)
|
||||||
|
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
|
||||||
|
logger.error(
|
||||||
|
"[SDK] [%s] AssistantMessage has error=%s, "
|
||||||
|
"content_blocks=%d, content_preview=%s",
|
||||||
|
session_id[:12],
|
||||||
|
sdk_error,
|
||||||
|
len(sdk_msg.content),
|
||||||
|
str(sdk_msg.content)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||||
# executed asynchronously via start_soon() — the next
|
# executed asynchronously via start_soon() — the next
|
||||||
# message can arrive before the hook stashes output.
|
# message can arrive before the hook stashes output.
|
||||||
@@ -1135,9 +1131,26 @@ async def stream_chat_completion_sdk(
|
|||||||
sdk_msg.result or "(no error message provided)",
|
sdk_msg.result or "(no error message provided)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Emit compaction end if SDK finished compacting
|
# Emit compaction end if SDK finished compacting.
|
||||||
for ev in await compaction.emit_end_if_ready(session):
|
# When compaction ends, sync TranscriptBuilder with the
|
||||||
|
# CLI's active context so they stay identical.
|
||||||
|
compact_result = await compaction.emit_end_if_ready(session)
|
||||||
|
for ev in compact_result.events:
|
||||||
yield ev
|
yield ev
|
||||||
|
# After replace_entries, skip append_assistant for this
|
||||||
|
# sdk_msg — the CLI session file already contains it,
|
||||||
|
# so appending again would create a duplicate.
|
||||||
|
entries_replaced = False
|
||||||
|
if compact_result.just_ended:
|
||||||
|
compacted = await asyncio.to_thread(
|
||||||
|
read_compacted_entries,
|
||||||
|
compact_result.transcript_path,
|
||||||
|
)
|
||||||
|
if compacted is not None:
|
||||||
|
transcript_builder.replace_entries(
|
||||||
|
compacted, log_prefix=log_prefix
|
||||||
|
)
|
||||||
|
entries_replaced = True
|
||||||
|
|
||||||
for response in adapter.convert_message(sdk_msg):
|
for response in adapter.convert_message(sdk_msg):
|
||||||
if isinstance(response, StreamStart):
|
if isinstance(response, StreamStart):
|
||||||
@@ -1224,10 +1237,11 @@ async def stream_chat_completion_sdk(
|
|||||||
tool_call_id=response.toolCallId,
|
tool_call_id=response.toolCallId,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
transcript_builder.append_tool_result(
|
if not entries_replaced:
|
||||||
tool_use_id=response.toolCallId,
|
transcript_builder.append_tool_result(
|
||||||
content=content,
|
tool_use_id=response.toolCallId,
|
||||||
)
|
content=content,
|
||||||
|
)
|
||||||
has_tool_results = True
|
has_tool_results = True
|
||||||
|
|
||||||
elif isinstance(response, StreamFinish):
|
elif isinstance(response, StreamFinish):
|
||||||
@@ -1237,7 +1251,9 @@ async def stream_chat_completion_sdk(
|
|||||||
# any stashed tool results from the previous turn are
|
# any stashed tool results from the previous turn are
|
||||||
# recorded first, preserving the required API order:
|
# recorded first, preserving the required API order:
|
||||||
# assistant(tool_use) → tool_result → assistant(text).
|
# assistant(tool_use) → tool_result → assistant(text).
|
||||||
if isinstance(sdk_msg, AssistantMessage):
|
# Skip if replace_entries just ran — the CLI session
|
||||||
|
# file already contains this message.
|
||||||
|
if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
|
||||||
transcript_builder.append_assistant(
|
transcript_builder.append_assistant(
|
||||||
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
|
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
|
||||||
model=sdk_msg.model,
|
model=sdk_msg.model,
|
||||||
@@ -1416,14 +1432,25 @@ async def stream_chat_completion_sdk(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --- Pause E2B sandbox to stop billing between turns ---
|
||||||
|
# Fire-and-forget: pausing is best-effort and must not block the
|
||||||
|
# response or the transcript upload. The task is anchored to
|
||||||
|
# _background_tasks to prevent garbage collection.
|
||||||
|
# Use pause_sandbox_direct to skip the Redis lookup and reconnect
|
||||||
|
# round-trip — e2b_sandbox is the live object from this turn.
|
||||||
|
if e2b_sandbox is not None:
|
||||||
|
task = asyncio.create_task(pause_sandbox_direct(e2b_sandbox, session_id))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
# --- Upload transcript for next-turn --resume ---
|
# --- Upload transcript for next-turn --resume ---
|
||||||
# This MUST run in finally so the transcript is uploaded even when
|
# TranscriptBuilder is the single source of truth. It mirrors the
|
||||||
# the streaming loop raises an exception.
|
# CLI's active context: on compaction, replace_entries() syncs it
|
||||||
# The transcript represents the COMPLETE active context (atomic).
|
# with the compacted session file. No CLI file read needed here.
|
||||||
if config.claude_agent_use_resume and user_id and session is not None:
|
if config.claude_agent_use_resume and user_id and session is not None:
|
||||||
try:
|
try:
|
||||||
# Build complete transcript from captured SDK messages
|
|
||||||
transcript_content = transcript_builder.to_jsonl()
|
transcript_content = transcript_builder.to_jsonl()
|
||||||
|
entry_count = transcript_builder.entry_count
|
||||||
|
|
||||||
if not transcript_content:
|
if not transcript_content:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1433,18 +1460,15 @@ async def stream_chat_completion_sdk(
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||||
log_prefix,
|
log_prefix,
|
||||||
transcript_builder.entry_count,
|
entry_count,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"%s Uploading complete transcript (entries=%d, bytes=%d)",
|
"%s Uploading transcript (entries=%d, bytes=%d)",
|
||||||
log_prefix,
|
log_prefix,
|
||||||
transcript_builder.entry_count,
|
entry_count,
|
||||||
len(transcript_content),
|
len(transcript_content),
|
||||||
)
|
)
|
||||||
# Shield upload from cancellation - let it complete even if
|
|
||||||
# the finally block is interrupted. No timeout to avoid race
|
|
||||||
# conditions where backgrounded uploads overwrite newer transcripts.
|
|
||||||
await asyncio.shield(
|
await asyncio.shield(
|
||||||
upload_transcript(
|
upload_transcript(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Tests for SDK service helpers."""
|
"""Tests for SDK service helpers."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -19,7 +20,7 @@ class _FakeFileInfo:
|
|||||||
size_bytes: int
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||||
|
|
||||||
|
|
||||||
class TestPrepareFileAttachments:
|
class TestPrepareFileAttachments:
|
||||||
@@ -212,7 +213,7 @@ class TestPromptSupplement:
|
|||||||
|
|
||||||
# Workflows are now in individual tool descriptions (not separate sections)
|
# Workflows are now in individual tool descriptions (not separate sections)
|
||||||
# Check that key workflow concepts appear in tool descriptions
|
# Check that key workflow concepts appear in tool descriptions
|
||||||
assert "suggested_goal" in docs or "clarifying_questions" in docs
|
assert "agent_json" in docs or "find_block" in docs
|
||||||
assert "run_mcp_tool" in docs
|
assert "run_mcp_tool" in docs
|
||||||
|
|
||||||
def test_baseline_supplement_completeness(self):
|
def test_baseline_supplement_completeness(self):
|
||||||
@@ -231,6 +232,48 @@ class TestPromptSupplement:
|
|||||||
f"`{tool_name}`" in docs
|
f"`{tool_name}`" in docs
|
||||||
), f"Tool '{tool_name}' missing from baseline supplement"
|
), f"Tool '{tool_name}' missing from baseline supplement"
|
||||||
|
|
||||||
|
def test_pause_task_scheduled_before_transcript_upload(self):
|
||||||
|
"""Pause is scheduled as a background task before transcript upload begins.
|
||||||
|
|
||||||
|
The finally block in stream_response_sdk does:
|
||||||
|
(1) asyncio.create_task(pause_sandbox_direct(...)) — fire-and-forget
|
||||||
|
(2) await asyncio.shield(upload_transcript(...)) — awaited
|
||||||
|
|
||||||
|
Scheduling pause via create_task before awaiting upload ensures:
|
||||||
|
- Pause never blocks transcript upload (billing stops concurrently)
|
||||||
|
- On E2B timeout, pause silently fails; upload proceeds unaffected
|
||||||
|
"""
|
||||||
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_pause(sandbox, session_id):
|
||||||
|
call_order.append("pause")
|
||||||
|
|
||||||
|
async def _mock_upload(**kwargs):
|
||||||
|
call_order.append("upload")
|
||||||
|
|
||||||
|
async def _simulate_teardown():
|
||||||
|
"""Mirror the service.py finally block teardown sequence."""
|
||||||
|
sandbox = MagicMock()
|
||||||
|
|
||||||
|
# (1) Schedule pause — mirrors lines ~1427-1429 in service.py
|
||||||
|
task = asyncio.create_task(_mock_pause(sandbox, "test-sess"))
|
||||||
|
|
||||||
|
# (2) Await transcript upload — mirrors lines ~1460-1468 in service.py
|
||||||
|
# Yielding to the event loop here lets the pause task start concurrently.
|
||||||
|
await _mock_upload(
|
||||||
|
user_id="u", session_id="test-sess", content="x", message_count=1
|
||||||
|
)
|
||||||
|
await task
|
||||||
|
|
||||||
|
asyncio.run(_simulate_teardown())
|
||||||
|
|
||||||
|
# Both must run; pause is scheduled before upload starts
|
||||||
|
assert "pause" in call_order
|
||||||
|
assert "upload" in call_order
|
||||||
|
# create_task schedules pause, then upload is awaited — pause runs
|
||||||
|
# concurrently during upload's first yield. The ordering guarantee is
|
||||||
|
# that create_task is CALLED before upload is AWAITED (see source order).
|
||||||
|
|
||||||
def test_baseline_supplement_no_duplicate_tools(self):
|
def test_baseline_supplement_no_duplicate_tools(self):
|
||||||
"""No tool should appear multiple times in baseline supplement."""
|
"""No tool should appear multiple times in baseline supplement."""
|
||||||
from backend.copilot.prompting import get_baseline_supplement
|
from backend.copilot.prompting import get_baseline_supplement
|
||||||
|
|||||||
@@ -9,14 +9,29 @@ import itertools
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||||
|
|
||||||
|
from backend.copilot.context import (
|
||||||
|
_current_project_dir,
|
||||||
|
_current_sandbox,
|
||||||
|
_current_sdk_cwd,
|
||||||
|
_current_session,
|
||||||
|
_current_user_id,
|
||||||
|
_encode_cwd_for_cli,
|
||||||
|
get_execution_context,
|
||||||
|
get_sdk_cwd,
|
||||||
|
is_allowed_local_path,
|
||||||
|
)
|
||||||
from backend.copilot.model import ChatSession
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.copilot.sdk.file_ref import (
|
||||||
|
FileRefExpansionError,
|
||||||
|
expand_file_refs_in_args,
|
||||||
|
read_file_bytes,
|
||||||
|
)
|
||||||
from backend.copilot.tools import TOOL_REGISTRY
|
from backend.copilot.tools import TOOL_REGISTRY
|
||||||
from backend.copilot.tools.base import BaseTool
|
from backend.copilot.tools.base import BaseTool
|
||||||
from backend.util.truncate import truncate
|
from backend.util.truncate import truncate
|
||||||
@@ -28,84 +43,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
|
||||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
|
||||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
|
||||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
|
||||||
|
|
||||||
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
|
||||||
_MCP_MAX_CHARS = 500_000
|
_MCP_MAX_CHARS = 500_000
|
||||||
|
|
||||||
# Context variable holding the encoded project directory name for the current
|
|
||||||
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
|
|
||||||
# so that path validation can scope tool-results reads to the current session.
|
|
||||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
|
||||||
"""Encode a working directory path the same way the Claude CLI does.
|
|
||||||
|
|
||||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
|
||||||
"""
|
|
||||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
|
||||||
|
|
||||||
|
|
||||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
|
||||||
"""Check whether *path* is an allowed host-filesystem path.
|
|
||||||
|
|
||||||
Allowed:
|
|
||||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
|
||||||
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
|
|
||||||
project directory for this session (tool-results, transcripts, etc.)
|
|
||||||
|
|
||||||
Both checks are scoped to the **current session** so sessions cannot
|
|
||||||
read each other's data.
|
|
||||||
"""
|
|
||||||
if not path:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if path.startswith("~"):
|
|
||||||
resolved = os.path.realpath(os.path.expanduser(path))
|
|
||||||
elif not os.path.isabs(path) and sdk_cwd:
|
|
||||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
|
||||||
else:
|
|
||||||
resolved = os.path.realpath(path)
|
|
||||||
|
|
||||||
# Allow access within the SDK working directory
|
|
||||||
if sdk_cwd:
|
|
||||||
norm_cwd = os.path.realpath(sdk_cwd)
|
|
||||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Allow access within the current session's CLI project directory
|
|
||||||
# (~/.claude/projects/<encoded-cwd>/).
|
|
||||||
encoded = _current_project_dir.get("")
|
|
||||||
if encoded:
|
|
||||||
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
|
|
||||||
if resolved == session_project or resolved.startswith(session_project + os.sep):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||||
MCP_SERVER_NAME = "copilot"
|
MCP_SERVER_NAME = "copilot"
|
||||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||||
|
|
||||||
# Context variables to pass user/session info to tool execution
|
|
||||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
|
||||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
|
||||||
"current_session", default=None
|
|
||||||
)
|
|
||||||
# E2B cloud sandbox for the current turn (None when E2B is not configured).
|
|
||||||
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
|
|
||||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
|
||||||
"_current_sandbox", default=None
|
|
||||||
)
|
|
||||||
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
|
|
||||||
# Used by workspace tools to save binary files for the CLI's built-in Read.
|
|
||||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
|
||||||
|
|
||||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||||
# response adapter when it builds StreamToolOutputAvailable.
|
# response adapter when it builds StreamToolOutputAvailable.
|
||||||
@@ -149,24 +93,6 @@ def set_execution_context(
|
|||||||
_stash_event.set(asyncio.Event())
|
_stash_event.set(asyncio.Event())
|
||||||
|
|
||||||
|
|
||||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
|
||||||
"""Return the E2B sandbox for the current turn, or None."""
|
|
||||||
return _current_sandbox.get()
|
|
||||||
|
|
||||||
|
|
||||||
def get_sdk_cwd() -> str:
|
|
||||||
"""Return the SDK ephemeral working directory for the current turn."""
|
|
||||||
return _current_sdk_cwd.get()
|
|
||||||
|
|
||||||
|
|
||||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
|
||||||
"""Get the current execution context."""
|
|
||||||
return (
|
|
||||||
_current_user_id.get(),
|
|
||||||
_current_session.get(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||||
"""Pop and return the oldest stashed output for *tool_name*.
|
"""Pop and return the oldest stashed output for *tool_name*.
|
||||||
|
|
||||||
@@ -259,7 +185,11 @@ async def _execute_tool_sync(
|
|||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
args: dict[str, Any],
|
args: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
"""Execute a tool synchronously and return MCP-formatted response.
|
||||||
|
|
||||||
|
Note: ``@@agptfile:`` expansion is handled upstream in the ``_truncating`` wrapper
|
||||||
|
so all registered handlers (BaseTool, E2B, Read) expand uniformly.
|
||||||
|
"""
|
||||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||||
result = await base_tool.execute(
|
result = await base_tool.execute(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -320,42 +250,50 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Read a local file with optional offset/limit.
|
"""Read a file with optional offset/limit.
|
||||||
|
|
||||||
Only allows paths that pass :func:`is_allowed_local_path` — the current
|
Supports ``workspace://`` URIs (delegated to the workspace manager) and
|
||||||
session's tool-results directory and ephemeral working directory.
|
local paths within the session's allowed directories (sdk_cwd + tool-results).
|
||||||
"""
|
"""
|
||||||
file_path = args.get("file_path", "")
|
file_path = args.get("file_path", "")
|
||||||
offset = args.get("offset", 0)
|
offset = max(0, int(args.get("offset", 0)))
|
||||||
limit = args.get("limit", 2000)
|
limit = max(1, int(args.get("limit", 2000)))
|
||||||
|
|
||||||
if not is_allowed_local_path(file_path):
|
def _mcp_err(text: str) -> dict[str, Any]:
|
||||||
return {
|
return {"content": [{"type": "text", "text": text}], "isError": True}
|
||||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
|
||||||
"isError": True,
|
def _mcp_ok(text: str) -> dict[str, Any]:
|
||||||
}
|
return {"content": [{"type": "text", "text": text}], "isError": False}
|
||||||
|
|
||||||
|
if file_path.startswith("workspace://"):
|
||||||
|
user_id, session = get_execution_context()
|
||||||
|
if session is None:
|
||||||
|
return _mcp_err("workspace:// file references require an active session")
|
||||||
|
try:
|
||||||
|
raw = await read_file_bytes(file_path, user_id, session)
|
||||||
|
except ValueError as exc:
|
||||||
|
return _mcp_err(str(exc))
|
||||||
|
lines = raw.decode("utf-8", errors="replace").splitlines(keepends=True)
|
||||||
|
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||||
|
numbered = "".join(
|
||||||
|
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||||
|
)
|
||||||
|
return _mcp_ok(numbered)
|
||||||
|
|
||||||
|
if not is_allowed_local_path(file_path, get_sdk_cwd()):
|
||||||
|
return _mcp_err(f"Path not allowed: {file_path}")
|
||||||
|
|
||||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||||
try:
|
try:
|
||||||
with open(resolved) as f:
|
with open(resolved) as f:
|
||||||
selected = list(itertools.islice(f, offset, offset + limit))
|
selected = list(itertools.islice(f, offset, offset + limit))
|
||||||
content = "".join(selected)
|
|
||||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||||
# don't delete here — the SDK may read in multiple chunks.
|
# don't delete here — the SDK may read in multiple chunks.
|
||||||
return {
|
return _mcp_ok("".join(selected))
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return {
|
return _mcp_err(f"File not found: {file_path}")
|
||||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return _mcp_err(f"Error reading file: {e}")
|
||||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_READ_TOOL_NAME = "Read"
|
_READ_TOOL_NAME = "Read"
|
||||||
@@ -409,14 +347,30 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
|||||||
:func:`get_sdk_disallowed_tools`.
|
:func:`get_sdk_disallowed_tools`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _truncating(fn, tool_name: str):
|
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||||
"""Wrap a tool handler so its response is truncated to stay under the
|
"""Wrap a tool handler so its response is truncated to stay under the
|
||||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||||
response adapter before the SDK can apply its own head-truncation.
|
response adapter before the SDK can apply its own head-truncation.
|
||||||
|
|
||||||
|
Also expands ``@@agptfile:`` references in args so every registered tool
|
||||||
|
(BaseTool, E2B file tools, Read) receives resolved content uniformly.
|
||||||
|
|
||||||
Applied once to every registered tool."""
|
Applied once to every registered tool."""
|
||||||
|
|
||||||
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
user_id, session = get_execution_context()
|
||||||
|
if session is not None:
|
||||||
|
try:
|
||||||
|
args = await expand_file_refs_in_args(
|
||||||
|
args, user_id, session, input_schema=input_schema
|
||||||
|
)
|
||||||
|
except FileRefExpansionError as exc:
|
||||||
|
return _mcp_error(
|
||||||
|
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||||
|
"Ensure the file exists before referencing it. "
|
||||||
|
"For sandbox paths use bash_exec to verify the file exists first; "
|
||||||
|
"for workspace files use a workspace:// URI."
|
||||||
|
)
|
||||||
result = await fn(args)
|
result = await fn(args)
|
||||||
truncated = truncate(result, _MCP_MAX_CHARS)
|
truncated = truncate(result, _MCP_MAX_CHARS)
|
||||||
|
|
||||||
@@ -437,11 +391,12 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
|||||||
|
|
||||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||||
handler = create_tool_handler(base_tool)
|
handler = create_tool_handler(base_tool)
|
||||||
|
schema = _build_input_schema(base_tool)
|
||||||
decorated = tool(
|
decorated = tool(
|
||||||
tool_name,
|
tool_name,
|
||||||
base_tool.description,
|
base_tool.description,
|
||||||
_build_input_schema(base_tool),
|
schema,
|
||||||
)(_truncating(handler, tool_name))
|
)(_truncating(handler, tool_name, input_schema=schema))
|
||||||
sdk_tools.append(decorated)
|
sdk_tools.append(decorated)
|
||||||
|
|
||||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.copilot.context import get_sdk_cwd
|
||||||
from backend.util.truncate import truncate
|
from backend.util.truncate import truncate
|
||||||
|
|
||||||
from .tool_adapter import (
|
from .tool_adapter import (
|
||||||
_MCP_MAX_CHARS,
|
_MCP_MAX_CHARS,
|
||||||
_text_from_mcp_result,
|
_text_from_mcp_result,
|
||||||
get_sdk_cwd,
|
|
||||||
pop_pending_tool_output,
|
pop_pending_tool_output,
|
||||||
set_execution_context,
|
set_execution_context,
|
||||||
stash_pending_tool_output,
|
stash_pending_tool_output,
|
||||||
|
|||||||
@@ -13,8 +13,10 @@ filesystem for self-hosted) — no DB column needed.
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
|
||||||
@@ -82,7 +84,11 @@ def strip_progress_entries(content: str) -> str:
|
|||||||
parent = entry.get("parentUuid", "")
|
parent = entry.get("parentUuid", "")
|
||||||
if uid:
|
if uid:
|
||||||
uuid_to_parent[uid] = parent
|
uuid_to_parent[uid] = parent
|
||||||
if entry.get("type", "") in STRIPPABLE_TYPES and uid:
|
if (
|
||||||
|
entry.get("type", "") in STRIPPABLE_TYPES
|
||||||
|
and uid
|
||||||
|
and not entry.get("isCompactSummary")
|
||||||
|
):
|
||||||
stripped_uuids.add(uid)
|
stripped_uuids.add(uid)
|
||||||
|
|
||||||
# Second pass: keep non-stripped entries, reparenting where needed.
|
# Second pass: keep non-stripped entries, reparenting where needed.
|
||||||
@@ -106,7 +112,9 @@ def strip_progress_entries(content: str) -> str:
|
|||||||
if not isinstance(entry, dict):
|
if not isinstance(entry, dict):
|
||||||
result_lines.append(line)
|
result_lines.append(line)
|
||||||
continue
|
continue
|
||||||
if entry.get("type", "") in STRIPPABLE_TYPES:
|
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||||
|
"isCompactSummary"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
uid = entry.get("uuid", "")
|
uid = entry.get("uuid", "")
|
||||||
if uid in reparented:
|
if uid in reparented:
|
||||||
@@ -137,6 +145,155 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
|||||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||||
|
|
||||||
|
|
||||||
|
def _projects_base() -> str:
|
||||||
|
"""Return the resolved path to the CLI's projects directory."""
|
||||||
|
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||||
|
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||||
|
|
||||||
|
|
||||||
|
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||||
|
"""Return the CLI's project directory for a given working directory.
|
||||||
|
|
||||||
|
Returns ``None`` if the path would escape the projects base.
|
||||||
|
"""
|
||||||
|
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||||
|
projects_base = _projects_base()
|
||||||
|
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||||
|
|
||||||
|
if not project_dir.startswith(projects_base + os.sep):
|
||||||
|
logger.warning(
|
||||||
|
"[Transcript] Project dir escaped projects base: %s", project_dir
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return project_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
|
||||||
|
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
|
||||||
|
try:
|
||||||
|
resolved_base = Path(project_dir).resolve()
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
result: list[Path] = []
|
||||||
|
for candidate in Path(project_dir).glob("*.jsonl"):
|
||||||
|
try:
|
||||||
|
resolved = candidate.resolve()
|
||||||
|
if resolved.is_relative_to(resolved_base):
|
||||||
|
result.append(resolved)
|
||||||
|
except (OSError, RuntimeError) as e:
|
||||||
|
logger.debug(
|
||||||
|
"[Transcript] Skipping invalid CLI session candidate %s: %s",
|
||||||
|
candidate,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||||
|
"""Read compacted entries from the CLI session file after compaction.
|
||||||
|
|
||||||
|
Parses the JSONL file line-by-line, finds the ``isCompactSummary: true``
|
||||||
|
entry, and returns it plus all entries after it.
|
||||||
|
|
||||||
|
The CLI writes the compaction summary BEFORE sending the next message,
|
||||||
|
so the file is guaranteed to be flushed by the time we read it.
|
||||||
|
|
||||||
|
Returns a list of parsed dicts, or ``None`` if the file cannot be read
|
||||||
|
or no compaction summary is found.
|
||||||
|
"""
|
||||||
|
if not transcript_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
projects_base = _projects_base()
|
||||||
|
real_path = os.path.realpath(transcript_path)
|
||||||
|
if not real_path.startswith(projects_base + os.sep):
|
||||||
|
logger.warning(
|
||||||
|
"[Transcript] transcript_path outside projects base: %s", transcript_path
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = Path(real_path).read_text()
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning(
|
||||||
|
"[Transcript] Failed to read session file %s: %s", transcript_path, e
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
compact_idx: int | None = None
|
||||||
|
|
||||||
|
for idx, line in enumerate(lines):
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
entry = json.loads(line, fallback=None)
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
if entry.get("isCompactSummary"):
|
||||||
|
compact_idx = idx # don't break — find the LAST summary
|
||||||
|
|
||||||
|
if compact_idx is None:
|
||||||
|
logger.debug("[Transcript] No compaction summary found in %s", transcript_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
entries: list[dict] = []
|
||||||
|
for line in lines[compact_idx:]:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
entry = json.loads(line, fallback=None)
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
entries.append(entry)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"[Transcript] Read %d compacted entries from %s (summary at line %d)",
|
||||||
|
len(entries),
|
||||||
|
transcript_path,
|
||||||
|
compact_idx + 1,
|
||||||
|
)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||||
|
"""Read the CLI's own session file, which reflects any compaction.
|
||||||
|
|
||||||
|
The CLI writes its session transcript to
|
||||||
|
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
|
||||||
|
Since each SDK turn uses a unique ``sdk_cwd``, there should be
|
||||||
|
exactly one ``.jsonl`` file in that directory.
|
||||||
|
|
||||||
|
Returns the file content, or ``None`` if not found.
|
||||||
|
"""
|
||||||
|
project_dir = _cli_project_dir(sdk_cwd)
|
||||||
|
if not project_dir or not os.path.isdir(project_dir):
|
||||||
|
return None
|
||||||
|
|
||||||
|
jsonl_files = _safe_glob_jsonl(project_dir)
|
||||||
|
if not jsonl_files:
|
||||||
|
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Pick the most recently modified file (should be only one per turn).
|
||||||
|
try:
|
||||||
|
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = session_file.read_text()
|
||||||
|
logger.info(
|
||||||
|
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||||
|
session_file,
|
||||||
|
len(content),
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||||
"""Remove the CLI's project directory for a specific working directory.
|
"""Remove the CLI's project directory for a specific working directory.
|
||||||
|
|
||||||
@@ -144,25 +301,15 @@ def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
|||||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||||
safe to remove entirely after the transcript has been uploaded.
|
safe to remove entirely after the transcript has been uploaded.
|
||||||
"""
|
"""
|
||||||
import shutil
|
project_dir = _cli_project_dir(sdk_cwd)
|
||||||
|
if not project_dir:
|
||||||
# Encode cwd the same way CLI does (replaces non-alphanumeric with -)
|
|
||||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
|
||||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
|
||||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
|
||||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
|
||||||
|
|
||||||
if not project_dir.startswith(projects_base + os.sep):
|
|
||||||
logger.warning(
|
|
||||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if os.path.isdir(project_dir):
|
if os.path.isdir(project_dir):
|
||||||
shutil.rmtree(project_dir, ignore_errors=True)
|
shutil.rmtree(project_dir, ignore_errors=True)
|
||||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||||
|
|
||||||
|
|
||||||
def write_transcript_to_tempfile(
|
def write_transcript_to_tempfile(
|
||||||
@@ -259,24 +406,27 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||||
"""Build the full storage path string that ``retrieve()`` expects.
|
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||||
|
|
||||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
|
||||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
|
||||||
arguments we can reconstruct the same path for download/delete without
|
|
||||||
having stored the return value.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||||
|
|
||||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
wid, fid, fname = parts
|
||||||
|
|
||||||
if isinstance(backend, GCSWorkspaceStorage):
|
if isinstance(backend, GCSWorkspaceStorage):
|
||||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||||
return f"gcs://{backend.bucket_name}/{blob}"
|
return f"gcs://{backend.bucket_name}/{blob}"
|
||||||
else:
|
return f"local://{wid}/{fid}/{fname}"
|
||||||
# LocalWorkspaceStorage returns local://{relative_path}
|
|
||||||
return f"local://{wid}/{fid}/{fname}"
|
|
||||||
|
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||||
|
"""Build the full storage path string that ``retrieve()`` expects."""
|
||||||
|
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||||
|
"""Build the full storage path for the companion .meta.json file."""
|
||||||
|
return _build_path_from_parts(
|
||||||
|
_meta_storage_path_parts(user_id, session_id), backend
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def upload_transcript(
|
async def upload_transcript(
|
||||||
@@ -381,15 +531,7 @@ async def download_transcript(
|
|||||||
message_count = 0
|
message_count = 0
|
||||||
uploaded_at = 0.0
|
uploaded_at = 0.0
|
||||||
try:
|
try:
|
||||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||||
|
|
||||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
|
||||||
if isinstance(storage, GCSWorkspaceStorage):
|
|
||||||
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
|
|
||||||
meta_path = f"gcs://{storage.bucket_name}/{blob}"
|
|
||||||
else:
|
|
||||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
|
||||||
|
|
||||||
meta_data = await storage.retrieve(meta_path)
|
meta_data = await storage.retrieve(meta_path)
|
||||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||||
message_count = meta.get("message_count", 0)
|
message_count = meta.get("message_count", 0)
|
||||||
@@ -406,7 +548,11 @@ async def download_transcript(
|
|||||||
|
|
||||||
|
|
||||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
"""Delete transcript and its metadata from bucket storage.
|
||||||
|
|
||||||
|
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||||
|
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||||
|
"""
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
storage = await get_workspace_storage()
|
||||||
@@ -414,6 +560,14 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await storage.delete(path)
|
await storage.delete(path)
|
||||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
logger.info("[Transcript] Deleted transcript for session %s", session_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
logger.warning("[Transcript] Failed to delete transcript: %s", e)
|
||||||
|
|
||||||
|
# Also delete the companion .meta.json to avoid orphaned metadata.
|
||||||
|
try:
|
||||||
|
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||||
|
await storage.delete(meta_path)
|
||||||
|
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class TranscriptEntry(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
uuid: str
|
uuid: str
|
||||||
parentUuid: str | None
|
parentUuid: str | None
|
||||||
|
isCompactSummary: bool | None = None
|
||||||
message: dict[str, Any]
|
message: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@@ -53,6 +54,24 @@ class TranscriptBuilder:
|
|||||||
return self._entries[-1].message.get("id", "")
|
return self._entries[-1].message.get("id", "")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_entry(data: dict) -> TranscriptEntry | None:
|
||||||
|
"""Parse a single transcript entry, filtering strippable types.
|
||||||
|
|
||||||
|
Returns ``None`` for entries that should be skipped (strippable types
|
||||||
|
that are not compaction summaries).
|
||||||
|
"""
|
||||||
|
entry_type = data.get("type", "")
|
||||||
|
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
|
||||||
|
return None
|
||||||
|
return TranscriptEntry(
|
||||||
|
type=entry_type,
|
||||||
|
uuid=data.get("uuid") or str(uuid4()),
|
||||||
|
parentUuid=data.get("parentUuid"),
|
||||||
|
isCompactSummary=data.get("isCompactSummary") or None,
|
||||||
|
message=data.get("message", {}),
|
||||||
|
)
|
||||||
|
|
||||||
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
|
||||||
"""Load complete previous transcript.
|
"""Load complete previous transcript.
|
||||||
|
|
||||||
@@ -78,18 +97,9 @@ class TranscriptBuilder:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Load all non-strippable entries (user/assistant/system/etc.)
|
entry = self._parse_entry(data)
|
||||||
# Skip only STRIPPABLE_TYPES to match strip_progress_entries() behavior
|
if entry is None:
|
||||||
entry_type = data.get("type", "")
|
|
||||||
if entry_type in STRIPPABLE_TYPES:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
entry = TranscriptEntry(
|
|
||||||
type=data["type"],
|
|
||||||
uuid=data.get("uuid") or str(uuid4()),
|
|
||||||
parentUuid=data.get("parentUuid"),
|
|
||||||
message=data.get("message", {}),
|
|
||||||
)
|
|
||||||
self._entries.append(entry)
|
self._entries.append(entry)
|
||||||
self._last_uuid = entry.uuid
|
self._last_uuid = entry.uuid
|
||||||
|
|
||||||
@@ -162,6 +172,43 @@ class TranscriptBuilder:
|
|||||||
)
|
)
|
||||||
self._last_uuid = msg_uuid
|
self._last_uuid = msg_uuid
|
||||||
|
|
||||||
|
def replace_entries(
|
||||||
|
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
|
||||||
|
) -> None:
|
||||||
|
"""Replace all entries with compacted entries from the CLI session file.
|
||||||
|
|
||||||
|
Called after mid-stream compaction so TranscriptBuilder mirrors the
|
||||||
|
CLI's active context (compaction summary + post-compaction entries).
|
||||||
|
|
||||||
|
Builds the new list first and validates it's non-empty before swapping,
|
||||||
|
so corrupt input cannot wipe the conversation history.
|
||||||
|
"""
|
||||||
|
new_entries: list[TranscriptEntry] = []
|
||||||
|
for data in compacted_entries:
|
||||||
|
entry = self._parse_entry(data)
|
||||||
|
if entry is not None:
|
||||||
|
new_entries.append(entry)
|
||||||
|
|
||||||
|
if not new_entries:
|
||||||
|
logger.warning(
|
||||||
|
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
|
||||||
|
log_prefix,
|
||||||
|
len(compacted_entries),
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
old_count = len(self._entries)
|
||||||
|
self._entries = new_entries
|
||||||
|
self._last_uuid = new_entries[-1].uuid
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"%s TranscriptBuilder compacted: %d entries -> %d entries",
|
||||||
|
log_prefix,
|
||||||
|
old_count,
|
||||||
|
len(self._entries),
|
||||||
|
)
|
||||||
|
|
||||||
def to_jsonl(self) -> str:
|
def to_jsonl(self) -> str:
|
||||||
"""Export complete context as JSONL.
|
"""Export complete context as JSONL.
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,23 @@
|
|||||||
"""Unit tests for JSONL transcript management utilities."""
|
"""Unit tests for JSONL transcript management utilities."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
|
||||||
from .transcript import (
|
from .transcript import (
|
||||||
STRIPPABLE_TYPES,
|
STRIPPABLE_TYPES,
|
||||||
|
_cli_project_dir,
|
||||||
|
delete_transcript,
|
||||||
|
read_cli_session_file,
|
||||||
|
read_compacted_entries,
|
||||||
strip_progress_entries,
|
strip_progress_entries,
|
||||||
validate_transcript,
|
validate_transcript,
|
||||||
write_transcript_to_tempfile,
|
write_transcript_to_tempfile,
|
||||||
)
|
)
|
||||||
|
from .transcript_builder import TranscriptBuilder
|
||||||
|
|
||||||
|
|
||||||
def _make_jsonl(*entries: dict) -> str:
|
def _make_jsonl(*entries: dict) -> str:
|
||||||
@@ -282,3 +290,610 @@ class TestStripProgressEntries:
|
|||||||
lines = result.strip().split("\n")
|
lines = result.strip().split("\n")
|
||||||
asst_entry = json.loads(lines[-1])
|
asst_entry = json.loads(lines[-1])
|
||||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||||
|
|
||||||
|
|
||||||
|
# --- read_cli_session_file ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadCliSessionFile:
|
||||||
|
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""read_cli_session_file returns None when no .jsonl files exist."""
|
||||||
|
# Create a project dir with no jsonl files
|
||||||
|
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||||
|
project_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||||
|
lambda sdk_cwd: str(project_dir),
|
||||||
|
)
|
||||||
|
assert read_cli_session_file("/fake/cwd") is None
|
||||||
|
|
||||||
|
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
|
||||||
|
"""read_cli_session_file returns the content of a single .jsonl file."""
|
||||||
|
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||||
|
project_dir.mkdir(parents=True)
|
||||||
|
jsonl_file = project_dir / "session.jsonl"
|
||||||
|
jsonl_file.write_text("line1\nline2\n")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||||
|
lambda sdk_cwd: str(project_dir),
|
||||||
|
)
|
||||||
|
result = read_cli_session_file("/fake/cwd")
|
||||||
|
assert result == "line1\nline2\n"
|
||||||
|
|
||||||
|
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
|
||||||
|
"""read_cli_session_file skips symlinks that escape the project dir."""
|
||||||
|
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||||
|
project_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Create a file outside the project dir
|
||||||
|
outside = tmp_path / "outside"
|
||||||
|
outside.mkdir()
|
||||||
|
outside_file = outside / "evil.jsonl"
|
||||||
|
outside_file.write_text("should not be read\n")
|
||||||
|
|
||||||
|
# Symlink from inside project_dir to outside file
|
||||||
|
symlink = project_dir / "evil.jsonl"
|
||||||
|
symlink.symlink_to(outside_file)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||||
|
lambda sdk_cwd: str(project_dir),
|
||||||
|
)
|
||||||
|
# The symlink target resolves outside project_dir, so it should be skipped
|
||||||
|
result = read_cli_session_file("/fake/cwd")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- _cli_project_dir ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliProjectDir:
|
||||||
|
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||||
|
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
config_dir.mkdir()
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
projects_dir.mkdir()
|
||||||
|
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
# Create a symlink inside projects/ that points outside of it.
|
||||||
|
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
|
||||||
|
# cwd whose encoded form matches the symlink name we create.
|
||||||
|
evil_target = tmp_path / "escaped"
|
||||||
|
evil_target.mkdir()
|
||||||
|
|
||||||
|
# The encoded form of "/evil/cwd" is "-evil-cwd"
|
||||||
|
symlink_path = projects_dir / "-evil-cwd"
|
||||||
|
symlink_path.symlink_to(evil_target)
|
||||||
|
|
||||||
|
result = _cli_project_dir("/evil/cwd")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- delete_transcript ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteTranscript:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deletes_both_jsonl_and_meta(self):
|
||||||
|
"""delete_transcript removes both the .jsonl and .meta.json files."""
|
||||||
|
mock_storage = AsyncMock()
|
||||||
|
mock_storage.delete = AsyncMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.util.workspace_storage.get_workspace_storage",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_storage,
|
||||||
|
):
|
||||||
|
await delete_transcript("user-123", "session-456")
|
||||||
|
|
||||||
|
assert mock_storage.delete.call_count == 2
|
||||||
|
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||||
|
assert any(p.endswith(".jsonl") for p in paths)
|
||||||
|
assert any(p.endswith(".meta.json") for p in paths)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_continues_on_jsonl_delete_failure(self):
|
||||||
|
"""If .jsonl delete fails, .meta.json delete is still attempted."""
|
||||||
|
mock_storage = AsyncMock()
|
||||||
|
mock_storage.delete = AsyncMock(
|
||||||
|
side_effect=[Exception("jsonl delete failed"), None]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.util.workspace_storage.get_workspace_storage",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_storage,
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await delete_transcript("user-123", "session-456")
|
||||||
|
|
||||||
|
assert mock_storage.delete.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_meta_delete_failure(self):
|
||||||
|
"""If .meta.json delete fails, no exception propagates."""
|
||||||
|
mock_storage = AsyncMock()
|
||||||
|
mock_storage.delete = AsyncMock(
|
||||||
|
side_effect=[None, Exception("meta delete failed")]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.util.workspace_storage.get_workspace_storage",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_storage,
|
||||||
|
):
|
||||||
|
# Should not raise
|
||||||
|
await delete_transcript("user-123", "session-456")
|
||||||
|
|
||||||
|
|
||||||
|
# --- read_compacted_entries ---
|
||||||
|
|
||||||
|
|
||||||
|
COMPACT_SUMMARY = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "assistant", "content": "compacted context"},
|
||||||
|
}
|
||||||
|
POST_COMPACT_ASST = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "response after compaction"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadCompactedEntries:
|
||||||
|
def test_returns_summary_and_entries_after(self, tmp_path, monkeypatch):
|
||||||
|
"""File with isCompactSummary entry returns summary + entries after."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
pre_compact = {"type": "user", "uuid": "u1", "message": {"role": "user"}}
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
path.write_text(_make_jsonl(pre_compact, COMPACT_SUMMARY, POST_COMPACT_ASST))
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(path))
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["isCompactSummary"] is True
|
||||||
|
assert result[1]["uuid"] == "a2"
|
||||||
|
|
||||||
|
def test_no_compact_summary_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""File without isCompactSummary returns None."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
path.write_text(_make_jsonl(USER_MSG, ASST_MSG))
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(path))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_file_not_found_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Non-existent file returns None."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
projects_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(projects_dir / "missing.jsonl"))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_empty_path_returns_none(self):
|
||||||
|
"""Empty string path returns None."""
|
||||||
|
result = read_compacted_entries("")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_malformed_json_lines_skipped(self, tmp_path, monkeypatch):
|
||||||
|
"""Malformed JSON lines are skipped gracefully."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
content = "not valid json\n" + json.dumps(COMPACT_SUMMARY) + "\n"
|
||||||
|
content += "also bad\n" + json.dumps(POST_COMPACT_ASST) + "\n"
|
||||||
|
path.write_text(content)
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(path))
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 2 # summary + post-compact assistant
|
||||||
|
|
||||||
|
def test_multiple_compact_summaries_uses_last(self, tmp_path, monkeypatch):
|
||||||
|
"""When multiple isCompactSummary entries exist, uses the last one
|
||||||
|
(most recent compaction)."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
second_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs2",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "assistant", "content": "second summary"},
|
||||||
|
}
|
||||||
|
path = session_dir / "session.jsonl"
|
||||||
|
path.write_text(_make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST, second_summary))
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(path))
|
||||||
|
assert result is not None
|
||||||
|
# Last summary found, so only cs2 returned
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["uuid"] == "cs2"
|
||||||
|
|
||||||
|
def test_path_outside_projects_base_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Transcript path outside the projects directory is rejected."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
(config_dir / "projects").mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
evil_file = tmp_path / "evil.jsonl"
|
||||||
|
evil_file.write_text(_make_jsonl(COMPACT_SUMMARY))
|
||||||
|
|
||||||
|
result = read_compacted_entries(str(evil_file))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- TranscriptBuilder.replace_entries ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptBuilderReplaceEntries:
|
||||||
|
def test_replaces_existing_entries(self):
|
||||||
|
"""replace_entries replaces all entries with compacted ones."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "world"}])
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
compacted = [
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "compacted summary"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "response"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[0]["uuid"] == "cs1"
|
||||||
|
assert entries[1]["uuid"] == "a1"
|
||||||
|
|
||||||
|
def test_filters_strippable_types(self):
|
||||||
|
"""Strippable types are filtered out during replace."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
compacted = [
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"message": {"role": "user", "content": "compacted summary"},
|
||||||
|
},
|
||||||
|
{"type": "progress", "uuid": "p1", "message": {}},
|
||||||
|
{"type": "summary", "uuid": "s1", "message": {}},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "hi"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 2 # progress and summary were filtered
|
||||||
|
|
||||||
|
def test_maintains_last_uuid_chain(self):
|
||||||
|
"""After replace, _last_uuid is the last entry's uuid."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
compacted = [
|
||||||
|
{
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"message": {"role": "user", "content": "compacted summary"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "hi"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
# Appending a new user message should chain to a1
|
||||||
|
builder.append_user("next question")
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[-1]["parentUuid"] == "a1"
|
||||||
|
|
||||||
|
def test_empty_entries_list_keeps_existing(self):
|
||||||
|
"""Replacing with empty list keeps existing entries (safety check)."""
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("hello")
|
||||||
|
builder.replace_entries([])
|
||||||
|
# Empty input is treated as corrupt — existing entries preserved
|
||||||
|
assert builder.entry_count == 1
|
||||||
|
assert not builder.is_empty
|
||||||
|
|
||||||
|
|
||||||
|
# --- TranscriptBuilder.load_previous with compacted content ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptBuilderLoadPreviousCompacted:
|
||||||
|
def test_preserves_compact_summary_entry(self):
|
||||||
|
"""load_previous preserves isCompactSummary entries even though
|
||||||
|
their type is 'summary' (which is in STRIPPABLE_TYPES)."""
|
||||||
|
compacted_content = _make_jsonl(COMPACT_SUMMARY, POST_COMPACT_ASST)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(compacted_content)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[0]["type"] == "summary"
|
||||||
|
assert entries[0]["uuid"] == "cs1"
|
||||||
|
assert entries[1]["uuid"] == "a2"
|
||||||
|
|
||||||
|
def test_strips_regular_summary_entries(self):
|
||||||
|
"""Regular summary entries (without isCompactSummary) are still stripped."""
|
||||||
|
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
|
||||||
|
content = _make_jsonl(regular_summary, POST_COMPACT_ASST)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(content)
|
||||||
|
assert builder.entry_count == 1 # Only the assistant entry
|
||||||
|
|
||||||
|
|
||||||
|
# --- End-to-end compaction flow (simulates service.py) ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionFlowIntegration:
|
||||||
|
"""Simulate the full compaction flow as it happens in service.py:
|
||||||
|
|
||||||
|
1. TranscriptBuilder loads a previous transcript (download)
|
||||||
|
2. New messages are appended (user query + assistant response)
|
||||||
|
3. CompactionTracker fires (PreCompact hook → emit_start → emit_end)
|
||||||
|
4. read_compacted_entries reads the CLI session file
|
||||||
|
5. TranscriptBuilder.replace_entries syncs with CLI state
|
||||||
|
6. Final to_jsonl() produces the correct output (upload)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_full_compaction_roundtrip(self, tmp_path, monkeypatch):
|
||||||
|
"""Full roundtrip: load → append → compact → replace → export."""
|
||||||
|
# Setup: create a CLI session file with pre-compact + compaction entries
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
# Simulate a transcript with old messages, then a compaction summary
|
||||||
|
old_user = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "old question"},
|
||||||
|
}
|
||||||
|
old_asst = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "u1",
|
||||||
|
"message": {"role": "assistant", "content": "old answer"},
|
||||||
|
}
|
||||||
|
compact_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "compacted summary of conversation"},
|
||||||
|
}
|
||||||
|
post_compact_asst = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "response after compaction"},
|
||||||
|
}
|
||||||
|
session_file = session_dir / "session.jsonl"
|
||||||
|
session_file.write_text(
|
||||||
|
_make_jsonl(old_user, old_asst, compact_summary, post_compact_asst)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 1: TranscriptBuilder loads previous transcript (simulates download)
|
||||||
|
# The previous transcript would have the OLD entries (pre-compaction)
|
||||||
|
previous_transcript = _make_jsonl(old_user, old_asst)
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(previous_transcript)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# Step 2: New messages appended during the current query
|
||||||
|
builder.append_user("new question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "new answer"}])
|
||||||
|
assert builder.entry_count == 4
|
||||||
|
|
||||||
|
# Step 3: read_compacted_entries reads the CLI session file
|
||||||
|
compacted = read_compacted_entries(str(session_file))
|
||||||
|
assert compacted is not None
|
||||||
|
assert len(compacted) == 2 # compact_summary + post_compact_asst
|
||||||
|
assert compacted[0]["isCompactSummary"] is True
|
||||||
|
|
||||||
|
# Step 4: replace_entries syncs builder with CLI state
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 2 # Only compacted entries now
|
||||||
|
|
||||||
|
# Step 5: Append post-compaction messages (continuing the stream)
|
||||||
|
builder.append_user("follow-up question")
|
||||||
|
assert builder.entry_count == 3
|
||||||
|
|
||||||
|
# Step 6: Export and verify
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert len(entries) == 3
|
||||||
|
# First entry is the compaction summary
|
||||||
|
assert entries[0]["type"] == "summary"
|
||||||
|
assert entries[0]["uuid"] == "cs1"
|
||||||
|
# Second is the post-compact assistant
|
||||||
|
assert entries[1]["uuid"] == "a2"
|
||||||
|
# Third is our follow-up, parented to the last compacted entry
|
||||||
|
assert entries[2]["type"] == "user"
|
||||||
|
assert entries[2]["parentUuid"] == "a2"
|
||||||
|
|
||||||
|
def test_compaction_preserves_chain_across_multiple_compactions(
|
||||||
|
self, tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
"""Two compactions: first compacts old history, second compacts the first."""
|
||||||
|
config_dir = tmp_path / "config"
|
||||||
|
projects_dir = config_dir / "projects"
|
||||||
|
session_dir = projects_dir / "proj"
|
||||||
|
session_dir.mkdir(parents=True)
|
||||||
|
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||||
|
|
||||||
|
# First compaction
|
||||||
|
first_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "first summary"},
|
||||||
|
}
|
||||||
|
mid_asst = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "mid response"},
|
||||||
|
}
|
||||||
|
# Second compaction (compacts the first summary + mid_asst)
|
||||||
|
second_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs2",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "second summary"},
|
||||||
|
}
|
||||||
|
final_asst = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "cs2",
|
||||||
|
"message": {"role": "assistant", "content": "final response"},
|
||||||
|
}
|
||||||
|
|
||||||
|
session_file = session_dir / "session.jsonl"
|
||||||
|
session_file.write_text(
|
||||||
|
_make_jsonl(first_summary, mid_asst, second_summary, final_asst)
|
||||||
|
)
|
||||||
|
|
||||||
|
# read_compacted_entries should find the LAST summary
|
||||||
|
compacted = read_compacted_entries(str(session_file))
|
||||||
|
assert compacted is not None
|
||||||
|
assert len(compacted) == 2 # second_summary + final_asst
|
||||||
|
assert compacted[0]["uuid"] == "cs2"
|
||||||
|
|
||||||
|
# Apply to builder
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.append_user("old stuff")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "old response"}])
|
||||||
|
builder.replace_entries(compacted)
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# New message chains correctly
|
||||||
|
builder.append_user("after second compaction")
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert entries[-1]["parentUuid"] == "a2"
|
||||||
|
|
||||||
|
def test_strip_progress_preserves_compact_summaries(self):
|
||||||
|
"""strip_progress_entries doesn't strip isCompactSummary entries
|
||||||
|
even though their type is 'summary' (in STRIPPABLE_TYPES)."""
|
||||||
|
compact_summary = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "compacted"},
|
||||||
|
}
|
||||||
|
regular_summary = {"type": "summary", "uuid": "s1", "message": {"content": "x"}}
|
||||||
|
progress = {"type": "progress", "uuid": "p1", "data": {"stdout": "..."}}
|
||||||
|
user = {
|
||||||
|
"type": "user",
|
||||||
|
"uuid": "u1",
|
||||||
|
"message": {"role": "user", "content": "hi"},
|
||||||
|
}
|
||||||
|
|
||||||
|
content = _make_jsonl(compact_summary, regular_summary, progress, user)
|
||||||
|
stripped = strip_progress_entries(content)
|
||||||
|
stripped_entries = [
|
||||||
|
json.loads(line) for line in stripped.strip().split("\n") if line.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
uuids = [e.get("uuid") for e in stripped_entries]
|
||||||
|
# compact_summary kept, regular_summary stripped, progress stripped, user kept
|
||||||
|
assert "cs1" in uuids # compact summary preserved
|
||||||
|
assert "s1" not in uuids # regular summary stripped
|
||||||
|
assert "p1" not in uuids # progress stripped
|
||||||
|
assert "u1" in uuids # user kept
|
||||||
|
|
||||||
|
def test_builder_load_then_replace_then_export_roundtrip(self):
|
||||||
|
"""Load a compacted transcript, replace with new compaction, export.
|
||||||
|
Simulates two consecutive turns with compaction each time."""
|
||||||
|
# Turn 1: load compacted transcript
|
||||||
|
compact1 = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs1",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "summary v1"},
|
||||||
|
}
|
||||||
|
asst1 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a1",
|
||||||
|
"parentUuid": "cs1",
|
||||||
|
"message": {"role": "assistant", "content": "response 1"},
|
||||||
|
}
|
||||||
|
builder = TranscriptBuilder()
|
||||||
|
builder.load_previous(_make_jsonl(compact1, asst1))
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# Turn 1: append new messages
|
||||||
|
builder.append_user("question")
|
||||||
|
builder.append_assistant([{"type": "text", "text": "answer"}])
|
||||||
|
assert builder.entry_count == 4
|
||||||
|
|
||||||
|
# Turn 1: compaction fires — replace with new compacted state
|
||||||
|
compact2 = {
|
||||||
|
"type": "summary",
|
||||||
|
"uuid": "cs2",
|
||||||
|
"isCompactSummary": True,
|
||||||
|
"message": {"role": "user", "content": "summary v2"},
|
||||||
|
}
|
||||||
|
asst2 = {
|
||||||
|
"type": "assistant",
|
||||||
|
"uuid": "a2",
|
||||||
|
"parentUuid": "cs2",
|
||||||
|
"message": {"role": "assistant", "content": "continuing"},
|
||||||
|
}
|
||||||
|
builder.replace_entries([compact2, asst2])
|
||||||
|
assert builder.entry_count == 2
|
||||||
|
|
||||||
|
# Export (this goes to cloud storage for next turn's download)
|
||||||
|
output = builder.to_jsonl()
|
||||||
|
lines = [json.loads(line) for line in output.strip().split("\n")]
|
||||||
|
assert lines[0]["uuid"] == "cs2"
|
||||||
|
assert lines[0]["type"] == "summary"
|
||||||
|
assert lines[1]["uuid"] == "a2"
|
||||||
|
|
||||||
|
# Turn 2: fresh builder loads the exported transcript
|
||||||
|
builder2 = TranscriptBuilder()
|
||||||
|
builder2.load_previous(output)
|
||||||
|
assert builder2.entry_count == 2
|
||||||
|
builder2.append_user("turn 2 question")
|
||||||
|
output2 = builder2.to_jsonl()
|
||||||
|
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||||
|
assert lines2[-1]["parentUuid"] == "a2"
|
||||||
|
|||||||
@@ -28,10 +28,24 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
||||||
|
_client: LangfuseAsyncOpenAI | None = None
|
||||||
|
_langfuse = None
|
||||||
|
|
||||||
|
|
||||||
langfuse = get_client()
|
def _get_openai_client() -> LangfuseAsyncOpenAI:
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def _get_langfuse():
|
||||||
|
global _langfuse
|
||||||
|
if _langfuse is None:
|
||||||
|
_langfuse = get_client()
|
||||||
|
return _langfuse
|
||||||
|
|
||||||
|
|
||||||
# Default system prompt used when Langfuse is not configured
|
# Default system prompt used when Langfuse is not configured
|
||||||
# Provides minimal baseline tone and personality - all workflow, tools, and
|
# Provides minimal baseline tone and personality - all workflow, tools, and
|
||||||
@@ -84,7 +98,7 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
else "latest"
|
else "latest"
|
||||||
)
|
)
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt,
|
_get_langfuse().get_prompt,
|
||||||
config.langfuse_prompt_name,
|
config.langfuse_prompt_name,
|
||||||
label=label,
|
label=label,
|
||||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||||
@@ -158,7 +172,7 @@ async def _generate_session_title(
|
|||||||
"environment": settings.config.app_env.value,
|
"environment": settings.config.app_env.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await client.chat.completions.create(
|
response = await _get_openai_client().chat.completions.create(
|
||||||
model=config.title_model,
|
model=config.title_model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
|
from backend.api.model import CopilotCompletionPayload
|
||||||
|
from backend.data.notification_bus import (
|
||||||
|
AsyncRedisNotificationEventBus,
|
||||||
|
NotificationEvent,
|
||||||
|
)
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
@@ -38,6 +43,7 @@ from .response_model import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
_notification_bus = AsyncRedisNotificationEventBus()
|
||||||
|
|
||||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||||
_local_sessions: dict[str, asyncio.Task] = {}
|
_local_sessions: dict[str, asyncio.Task] = {}
|
||||||
@@ -745,6 +751,29 @@ async def mark_session_completed(
|
|||||||
|
|
||||||
# Clean up local session reference if exists
|
# Clean up local session reference if exists
|
||||||
_local_sessions.pop(session_id, None)
|
_local_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
# Publish copilot completion notification via WebSocket
|
||||||
|
if meta:
|
||||||
|
parsed = _parse_session_meta(meta, session_id)
|
||||||
|
if parsed.user_id:
|
||||||
|
try:
|
||||||
|
await _notification_bus.publish(
|
||||||
|
NotificationEvent(
|
||||||
|
user_id=parsed.user_id,
|
||||||
|
payload=CopilotCompletionPayload(
|
||||||
|
type="copilot_completion",
|
||||||
|
event="session_completed",
|
||||||
|
session_id=session_id,
|
||||||
|
status=status,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to publish copilot completion notification "
|
||||||
|
f"for session {session_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
|||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .bash_exec import BashExecTool
|
from .bash_exec import BashExecTool
|
||||||
|
from .continue_run_block import ContinueRunBlockTool
|
||||||
from .create_agent import CreateAgentTool
|
from .create_agent import CreateAgentTool
|
||||||
from .customize_agent import CustomizeAgentTool
|
from .customize_agent import CustomizeAgentTool
|
||||||
from .edit_agent import EditAgentTool
|
from .edit_agent import EditAgentTool
|
||||||
@@ -19,7 +20,10 @@ from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsToo
|
|||||||
from .find_agent import FindAgentTool
|
from .find_agent import FindAgentTool
|
||||||
from .find_block import FindBlockTool
|
from .find_block import FindBlockTool
|
||||||
from .find_library_agent import FindLibraryAgentTool
|
from .find_library_agent import FindLibraryAgentTool
|
||||||
|
from .fix_agent import FixAgentGraphTool
|
||||||
|
from .get_agent_building_guide import GetAgentBuildingGuideTool
|
||||||
from .get_doc_page import GetDocPageTool
|
from .get_doc_page import GetDocPageTool
|
||||||
|
from .get_mcp_guide import GetMCPGuideTool
|
||||||
from .manage_folders import (
|
from .manage_folders import (
|
||||||
CreateFolderTool,
|
CreateFolderTool,
|
||||||
DeleteFolderTool,
|
DeleteFolderTool,
|
||||||
@@ -32,6 +36,7 @@ from .run_agent import RunAgentTool
|
|||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .run_mcp_tool import RunMCPToolTool
|
from .run_mcp_tool import RunMCPToolTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
from .validate_agent import ValidateAgentGraphTool
|
||||||
from .web_fetch import WebFetchTool
|
from .web_fetch import WebFetchTool
|
||||||
from .workspace_files import (
|
from .workspace_files import (
|
||||||
DeleteWorkspaceFileTool,
|
DeleteWorkspaceFileTool,
|
||||||
@@ -64,10 +69,13 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||||
"run_agent": RunAgentTool(),
|
"run_agent": RunAgentTool(),
|
||||||
"run_block": RunBlockTool(),
|
"run_block": RunBlockTool(),
|
||||||
|
"continue_run_block": ContinueRunBlockTool(),
|
||||||
"run_mcp_tool": RunMCPToolTool(),
|
"run_mcp_tool": RunMCPToolTool(),
|
||||||
|
"get_mcp_guide": GetMCPGuideTool(),
|
||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
|
"get_agent_building_guide": GetAgentBuildingGuideTool(),
|
||||||
# Web fetch for safe URL retrieval
|
# Web fetch for safe URL retrieval
|
||||||
"web_fetch": WebFetchTool(),
|
"web_fetch": WebFetchTool(),
|
||||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||||
@@ -80,6 +88,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
# Feature request tools
|
# Feature request tools
|
||||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||||
"create_feature_request": CreateFeatureRequestTool(),
|
"create_feature_request": CreateFeatureRequestTool(),
|
||||||
|
# Agent generation tools (local validation/fixing)
|
||||||
|
"validate_agent_graph": ValidateAgentGraphTool(),
|
||||||
|
"fix_agent_graph": FixAgentGraphTool(),
|
||||||
# Workspace tools for CoPilot file operations
|
# Workspace tools for CoPilot file operations
|
||||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||||
|
|||||||
@@ -32,8 +32,9 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.copilot.context import get_workspace_manager
|
||||||
from backend.copilot.model import ChatSession
|
from backend.copilot.model import ChatSession
|
||||||
from backend.util.request import validate_url
|
from backend.util.request import validate_url_host
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -43,7 +44,6 @@ from .models import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
)
|
)
|
||||||
from .workspace_files import get_manager
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
manager = await get_manager(user_id, session.session_id)
|
manager = await get_workspace_manager(user_id, session.session_id)
|
||||||
await manager.write_file(
|
await manager.write_file(
|
||||||
content=json.dumps(state).encode("utf-8"),
|
content=json.dumps(state).encode("utf-8"),
|
||||||
filename=_STATE_FILENAME,
|
filename=_STATE_FILENAME,
|
||||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
|||||||
Returns True on success (or no state to restore), False on failure.
|
Returns True on success (or no state to restore), False on failure.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
manager = await get_manager(user_id, session.session_id)
|
manager = await get_workspace_manager(user_id, session.session_id)
|
||||||
|
|
||||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||||
if file_info is None:
|
if file_info is None:
|
||||||
@@ -235,7 +235,7 @@ async def _restore_browser_state(
|
|||||||
if url:
|
if url:
|
||||||
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
# Validate the saved URL to prevent SSRF via stored redirect targets.
|
||||||
try:
|
try:
|
||||||
await validate_url(url, trusted_origins=[])
|
await validate_url_host(url)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
"[browser] State restore: blocked SSRF URL %s", url[:200]
|
||||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
|||||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||||
if user_id:
|
if user_id:
|
||||||
try:
|
try:
|
||||||
manager = await get_manager(user_id, session_name)
|
manager = await get_workspace_manager(user_id, session_name)
|
||||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||||
if file_info is not None:
|
if file_info is not None:
|
||||||
await manager.delete_file(file_info.id)
|
await manager.delete_file(file_info.id)
|
||||||
@@ -473,7 +473,7 @@ class BrowserNavigateTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await validate_url(url, trusted_origins=[])
|
await validate_url_host(url)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=str(e),
|
message=str(e),
|
||||||
|
|||||||
@@ -68,17 +68,18 @@ def _run_result(rc: int = 0, stdout: str = "", stderr: str = "") -> tuple:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# SSRF protection via shared validate_url (backend.util.request)
|
# SSRF protection via shared validate_url_host (backend.util.request)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Patch target: validate_url is imported directly into agent_browser's module scope.
|
# Patch target: validate_url_host is imported directly into agent_browser's
|
||||||
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url"
|
# module scope.
|
||||||
|
_VALIDATE_URL = "backend.copilot.tools.agent_browser.validate_url_host"
|
||||||
|
|
||||||
|
|
||||||
class TestSsrfViaValidateUrl:
|
class TestSsrfViaValidateUrl:
|
||||||
"""Verify that browser_navigate uses validate_url for SSRF protection.
|
"""Verify that browser_navigate uses validate_url_host for SSRF protection.
|
||||||
|
|
||||||
We mock validate_url itself (not the low-level socket) so these tests
|
We mock validate_url_host itself (not the low-level socket) so these tests
|
||||||
exercise the integration point, not the internals of request.py
|
exercise the integration point, not the internals of request.py
|
||||||
(which has its own thorough test suite in request_test.py).
|
(which has its own thorough test suite in request_test.py).
|
||||||
"""
|
"""
|
||||||
@@ -89,7 +90,7 @@ class TestSsrfViaValidateUrl:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_blocked_ip_returns_blocked_url_error(self):
|
async def test_blocked_ip_returns_blocked_url_error(self):
|
||||||
"""validate_url raises ValueError → tool returns blocked_url ErrorResponse."""
|
"""validate_url_host raises ValueError → tool returns blocked_url ErrorResponse."""
|
||||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||||
mock_validate.side_effect = ValueError(
|
mock_validate.side_effect = ValueError(
|
||||||
"Access to blocked IP 10.0.0.1 is not allowed."
|
"Access to blocked IP 10.0.0.1 is not allowed."
|
||||||
@@ -124,8 +125,8 @@ class TestSsrfViaValidateUrl:
|
|||||||
assert result.error == "blocked_url"
|
assert result.error == "blocked_url"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_url_called_with_empty_trusted_origins(self):
|
async def test_validate_url_host_called_without_trusted_hostnames(self):
|
||||||
"""Confirms no trusted-origins bypass is granted — all URLs are validated."""
|
"""Confirms no trusted-hostnames bypass is granted — all URLs are validated."""
|
||||||
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
with patch(_VALIDATE_URL, new_callable=AsyncMock) as mock_validate:
|
||||||
mock_validate.return_value = (object(), False, ["1.2.3.4"])
|
mock_validate.return_value = (object(), False, ["1.2.3.4"])
|
||||||
with patch(
|
with patch(
|
||||||
@@ -143,7 +144,7 @@ class TestSsrfViaValidateUrl:
|
|||||||
session=self.session,
|
session=self.session,
|
||||||
url="https://example.com",
|
url="https://example.com",
|
||||||
)
|
)
|
||||||
mock_validate.assert_called_once_with("https://example.com", trusted_origins=[])
|
mock_validate.assert_called_once_with("https://example.com")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -896,7 +897,7 @@ class TestHasLocalSession:
|
|||||||
# _save_browser_state
|
# _save_browser_state
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||||
|
|
||||||
|
|
||||||
def _make_mock_manager():
|
def _make_mock_manager():
|
||||||
|
|||||||
@@ -1,20 +1,15 @@
|
|||||||
"""Agent generator package - Creates agents from natural language."""
|
"""Agent generator package - Creates agents from natural language."""
|
||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
AgentGeneratorNotConfiguredError,
|
|
||||||
AgentJsonValidationError,
|
AgentJsonValidationError,
|
||||||
AgentSummary,
|
AgentSummary,
|
||||||
DecompositionResult,
|
DecompositionResult,
|
||||||
DecompositionStep,
|
DecompositionStep,
|
||||||
LibraryAgentSummary,
|
LibraryAgentSummary,
|
||||||
MarketplaceAgentSummary,
|
MarketplaceAgentSummary,
|
||||||
customize_template,
|
|
||||||
decompose_goal,
|
|
||||||
enrich_library_agents_from_steps,
|
enrich_library_agents_from_steps,
|
||||||
extract_search_terms_from_steps,
|
extract_search_terms_from_steps,
|
||||||
extract_uuids_from_text,
|
extract_uuids_from_text,
|
||||||
generate_agent,
|
|
||||||
generate_agent_patch,
|
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
get_all_relevant_agents_for_generation,
|
get_all_relevant_agents_for_generation,
|
||||||
get_library_agent_by_graph_id,
|
get_library_agent_by_graph_id,
|
||||||
@@ -27,25 +22,20 @@ from .core import (
|
|||||||
search_marketplace_agents_for_generation,
|
search_marketplace_agents_for_generation,
|
||||||
)
|
)
|
||||||
from .errors import get_user_message_for_error
|
from .errors import get_user_message_for_error
|
||||||
from .service import health_check as check_external_service_health
|
from .validation import AgentFixer, AgentValidator
|
||||||
from .service import is_external_service_configured
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentGeneratorNotConfiguredError",
|
"AgentFixer",
|
||||||
|
"AgentValidator",
|
||||||
"AgentJsonValidationError",
|
"AgentJsonValidationError",
|
||||||
"AgentSummary",
|
"AgentSummary",
|
||||||
"DecompositionResult",
|
"DecompositionResult",
|
||||||
"DecompositionStep",
|
"DecompositionStep",
|
||||||
"LibraryAgentSummary",
|
"LibraryAgentSummary",
|
||||||
"MarketplaceAgentSummary",
|
"MarketplaceAgentSummary",
|
||||||
"check_external_service_health",
|
|
||||||
"customize_template",
|
|
||||||
"decompose_goal",
|
|
||||||
"enrich_library_agents_from_steps",
|
"enrich_library_agents_from_steps",
|
||||||
"extract_search_terms_from_steps",
|
"extract_search_terms_from_steps",
|
||||||
"extract_uuids_from_text",
|
"extract_uuids_from_text",
|
||||||
"generate_agent",
|
|
||||||
"generate_agent_patch",
|
|
||||||
"get_agent_as_json",
|
"get_agent_as_json",
|
||||||
"get_all_relevant_agents_for_generation",
|
"get_all_relevant_agents_for_generation",
|
||||||
"get_library_agent_by_graph_id",
|
"get_library_agent_by_graph_id",
|
||||||
@@ -54,7 +44,6 @@ __all__ = [
|
|||||||
"get_library_agents_for_generation",
|
"get_library_agents_for_generation",
|
||||||
"get_user_message_for_error",
|
"get_user_message_for_error",
|
||||||
"graph_to_json",
|
"graph_to_json",
|
||||||
"is_external_service_configured",
|
|
||||||
"json_to_graph",
|
"json_to_graph",
|
||||||
"save_agent_to_library",
|
"save_agent_to_library",
|
||||||
"search_marketplace_agents_for_generation",
|
"search_marketplace_agents_for_generation",
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Block management for agent generation.
|
||||||
|
|
||||||
|
Provides cached access to block metadata for validation and fixing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
|
from backend.blocks import get_blocks as get_block_classes
|
||||||
|
from backend.blocks._base import Block
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["get_blocks_as_dicts", "reset_block_caches"]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level caches
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_blocks_cache: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def reset_block_caches() -> None:
|
||||||
|
"""Reset all module-level caches (useful after updating block descriptions)."""
|
||||||
|
global _blocks_cache
|
||||||
|
_blocks_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. get_blocks_as_dicts
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def get_blocks_as_dicts() -> list[dict[str, Any]]:
|
||||||
|
"""Get all available blocks as dicts (cached after first call).
|
||||||
|
|
||||||
|
Each dict contains the keys returned by ``Block.get_info().model_dump()``:
|
||||||
|
id, name, description, inputSchema, outputSchema, categories,
|
||||||
|
staticOutput, costs, contributors, uiType.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of block info dicts.
|
||||||
|
"""
|
||||||
|
global _blocks_cache
|
||||||
|
if _blocks_cache is not None:
|
||||||
|
return _blocks_cache
|
||||||
|
|
||||||
|
block_classes: dict[str, Type[Block]] = get_block_classes() # type: ignore[assignment]
|
||||||
|
blocks: list[dict[str, Any]] = []
|
||||||
|
for block_cls in block_classes.values():
|
||||||
|
try:
|
||||||
|
instance = block_cls()
|
||||||
|
info = instance.get_info().model_dump()
|
||||||
|
# Use optimized description if available (loaded at startup)
|
||||||
|
if instance.optimized_description:
|
||||||
|
info["description"] = instance.optimized_description
|
||||||
|
blocks.append(info)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to load block info for %s, skipping",
|
||||||
|
getattr(block_cls, "__name__", "unknown"),
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_blocks_cache = blocks
|
||||||
|
logger.info("Cached %d block dicts", len(blocks))
|
||||||
|
return _blocks_cache
|
||||||
@@ -10,13 +10,7 @@ from backend.data.db_accessors import graph_db, library_db, store_db
|
|||||||
from backend.data.graph import Graph, Link, Node
|
from backend.data.graph import Graph, Link, Node
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .helpers import UUID_RE_STR
|
||||||
customize_template_external,
|
|
||||||
decompose_goal_external,
|
|
||||||
generate_agent_external,
|
|
||||||
generate_agent_patch_external,
|
|
||||||
is_external_service_configured,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -78,38 +72,7 @@ class DecompositionResult(TypedDict, total=False):
|
|||||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
def _to_dict_list(
|
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
|
||||||
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
|
|
||||||
) -> list[dict[str, Any]] | None:
|
|
||||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
|
||||||
if agents is None:
|
|
||||||
return None
|
|
||||||
return [dict(a) for a in agents]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentGeneratorNotConfiguredError(Exception):
|
|
||||||
"""Raised when the external Agent Generator service is not configured."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _check_service_configured() -> None:
|
|
||||||
"""Check if the external Agent Generator service is configured.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
|
||||||
"""
|
|
||||||
if not is_external_service_configured():
|
|
||||||
raise AgentGeneratorNotConfiguredError(
|
|
||||||
"Agent Generator service is not configured. "
|
|
||||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_UUID_PATTERN = re.compile(
|
|
||||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_uuids_from_text(text: str) -> list[str]:
|
def extract_uuids_from_text(text: str) -> list[str]:
|
||||||
@@ -553,69 +516,6 @@ async def enrich_library_agents_from_steps(
|
|||||||
return all_agents
|
return all_agents
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal(
|
|
||||||
description: str,
|
|
||||||
context: str = "",
|
|
||||||
library_agents: Sequence[AgentSummary] | None = None,
|
|
||||||
) -> DecompositionResult | None:
|
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description: Natural language goal description
|
|
||||||
context: Additional context (e.g., answers to previous questions)
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DecompositionResult with either:
|
|
||||||
- {"type": "clarifying_questions", "questions": [...]}
|
|
||||||
- {"type": "instructions", "steps": [...]}
|
|
||||||
Or None on error
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
|
||||||
"""
|
|
||||||
_check_service_configured()
|
|
||||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
|
||||||
result = await decompose_goal_external(
|
|
||||||
description, context, _to_dict_list(library_agents)
|
|
||||||
)
|
|
||||||
return result # type: ignore[return-value]
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent(
|
|
||||||
instructions: DecompositionResult | dict[str, Any],
|
|
||||||
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Generate agent JSON from instructions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instructions: Structured instructions from decompose_goal
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
|
||||||
"""
|
|
||||||
_check_service_configured()
|
|
||||||
logger.info("Calling external Agent Generator service for generate_agent")
|
|
||||||
result = await generate_agent_external(
|
|
||||||
dict(instructions), _to_dict_list(library_agents)
|
|
||||||
)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
|
||||||
return result
|
|
||||||
if "id" not in result:
|
|
||||||
result["id"] = str(uuid.uuid4())
|
|
||||||
if "version" not in result:
|
|
||||||
result["version"] = 1
|
|
||||||
if "is_active" not in result:
|
|
||||||
result["is_active"] = True
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AgentJsonValidationError(Exception):
|
class AgentJsonValidationError(Exception):
|
||||||
"""Raised when agent JSON is invalid or missing required fields."""
|
"""Raised when agent JSON is invalid or missing required fields."""
|
||||||
|
|
||||||
@@ -792,70 +692,3 @@ async def get_agent_as_json(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return graph_to_json(graph)
|
return graph_to_json(graph)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch(
|
|
||||||
update_request: str,
|
|
||||||
current_agent: dict[str, Any],
|
|
||||||
library_agents: Sequence[AgentSummary] | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Update an existing agent using natural language.
|
|
||||||
|
|
||||||
The external Agent Generator service handles:
|
|
||||||
- Generating the patch
|
|
||||||
- Applying the patch
|
|
||||||
- Fixing and validating the result
|
|
||||||
|
|
||||||
Args:
|
|
||||||
update_request: Natural language description of changes
|
|
||||||
current_agent: Current agent JSON
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
|
||||||
error dict {"type": "error", ...}, or None on error
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
|
||||||
"""
|
|
||||||
_check_service_configured()
|
|
||||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
|
||||||
return await generate_agent_patch_external(
|
|
||||||
update_request,
|
|
||||||
current_agent,
|
|
||||||
_to_dict_list(library_agents),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def customize_template(
|
|
||||||
template_agent: dict[str, Any],
|
|
||||||
modification_request: str,
|
|
||||||
context: str = "",
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Customize a template/marketplace agent using natural language.
|
|
||||||
|
|
||||||
This is used when users want to modify a template or marketplace agent
|
|
||||||
to fit their specific needs before adding it to their library.
|
|
||||||
|
|
||||||
The external Agent Generator service handles:
|
|
||||||
- Understanding the modification request
|
|
||||||
- Applying changes to the template
|
|
||||||
- Fixing and validating the result
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template_agent: The template agent JSON to customize
|
|
||||||
modification_request: Natural language description of customizations
|
|
||||||
context: Additional context (e.g., answers to previous questions)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
|
||||||
error dict {"type": "error", ...}, or None on unexpected error
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
|
||||||
"""
|
|
||||||
_check_service_configured()
|
|
||||||
logger.info("Calling external Agent Generator service for customize_template")
|
|
||||||
return await customize_template_external(
|
|
||||||
template_agent, modification_request, context
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Dummy Agent Generator for testing.
|
|
||||||
|
|
||||||
Returns mock responses matching the format expected from the external service.
|
|
||||||
Enable via AGENTGENERATOR_USE_DUMMY=true in settings.
|
|
||||||
|
|
||||||
WARNING: This is for testing only. Do not use in production.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Dummy decomposition result (instructions type)
|
|
||||||
DUMMY_DECOMPOSITION_RESULT: dict[str, Any] = {
|
|
||||||
"type": "instructions",
|
|
||||||
"steps": [
|
|
||||||
{
|
|
||||||
"description": "Get input from user",
|
|
||||||
"action": "input",
|
|
||||||
"block_name": "AgentInputBlock",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"description": "Process the input",
|
|
||||||
"action": "process",
|
|
||||||
"block_name": "TextFormatterBlock",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"description": "Return output to user",
|
|
||||||
"action": "output",
|
|
||||||
"block_name": "AgentOutputBlock",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Block IDs from backend/blocks/io.py
|
|
||||||
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
|
||||||
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_dummy_agent_json() -> dict[str, Any]:
|
|
||||||
"""Generate a minimal valid agent JSON for testing."""
|
|
||||||
input_node_id = str(uuid.uuid4())
|
|
||||||
output_node_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"version": 1,
|
|
||||||
"is_active": True,
|
|
||||||
"name": "Dummy Test Agent",
|
|
||||||
"description": "A dummy agent generated for testing purposes",
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"id": input_node_id,
|
|
||||||
"block_id": AGENT_INPUT_BLOCK_ID,
|
|
||||||
"input_default": {
|
|
||||||
"name": "input",
|
|
||||||
"title": "Input",
|
|
||||||
"description": "Enter your input",
|
|
||||||
"placeholder_values": [],
|
|
||||||
},
|
|
||||||
"metadata": {"position": {"x": 0, "y": 0}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": output_node_id,
|
|
||||||
"block_id": AGENT_OUTPUT_BLOCK_ID,
|
|
||||||
"input_default": {
|
|
||||||
"name": "output",
|
|
||||||
"title": "Output",
|
|
||||||
"description": "Agent output",
|
|
||||||
"format": "{output}",
|
|
||||||
},
|
|
||||||
"metadata": {"position": {"x": 400, "y": 0}},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"links": [
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": input_node_id,
|
|
||||||
"sink_id": output_node_id,
|
|
||||||
"source_name": "result",
|
|
||||||
"sink_name": "value",
|
|
||||||
"is_static": False,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal_dummy(
|
|
||||||
description: str,
|
|
||||||
context: str = "",
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return dummy decomposition result."""
|
|
||||||
logger.info("Using dummy agent generator for decompose_goal")
|
|
||||||
return DUMMY_DECOMPOSITION_RESULT.copy()
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_dummy(
|
|
||||||
instructions: dict[str, Any],
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
operation_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
|
|
||||||
|
|
||||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
|
||||||
"""
|
|
||||||
logger.info(
|
|
||||||
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(30)
|
|
||||||
return _generate_dummy_agent_json()
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch_dummy(
|
|
||||||
update_request: str,
|
|
||||||
current_agent: dict[str, Any],
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
operation_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
|
|
||||||
|
|
||||||
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
|
|
||||||
"""
|
|
||||||
logger.info(
|
|
||||||
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(30)
|
|
||||||
patched = current_agent.copy()
|
|
||||||
patched["description"] = (
|
|
||||||
f"{current_agent.get('description', '')} (updated: {update_request})"
|
|
||||||
)
|
|
||||||
return patched
|
|
||||||
|
|
||||||
|
|
||||||
async def customize_template_dummy(
|
|
||||||
template_agent: dict[str, Any],
|
|
||||||
modification_request: str,
|
|
||||||
context: str = "",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return dummy customized template (returns template with updated description)."""
|
|
||||||
logger.info("Using dummy agent generator for customize_template")
|
|
||||||
customized = template_agent.copy()
|
|
||||||
customized["description"] = (
|
|
||||||
f"{template_agent.get('description', '')} (customized: {modification_request})"
|
|
||||||
)
|
|
||||||
return customized
|
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_dummy() -> list[dict[str, Any]]:
|
|
||||||
"""Return dummy blocks list."""
|
|
||||||
logger.info("Using dummy agent generator for get_blocks")
|
|
||||||
return [
|
|
||||||
{"id": AGENT_INPUT_BLOCK_ID, "name": "AgentInputBlock"},
|
|
||||||
{"id": AGENT_OUTPUT_BLOCK_ID, "name": "AgentOutputBlock"},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def health_check_dummy() -> bool:
|
|
||||||
"""Always returns healthy for dummy service."""
|
|
||||||
return True
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,67 @@
|
|||||||
|
"""Shared helpers for agent generation."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .blocks import get_blocks_as_dicts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AGENT_EXECUTOR_BLOCK_ID",
|
||||||
|
"AGENT_INPUT_BLOCK_ID",
|
||||||
|
"AGENT_OUTPUT_BLOCK_ID",
|
||||||
|
"AgentDict",
|
||||||
|
"MCP_TOOL_BLOCK_ID",
|
||||||
|
"UUID_REGEX",
|
||||||
|
"are_types_compatible",
|
||||||
|
"generate_uuid",
|
||||||
|
"get_blocks_as_dicts",
|
||||||
|
"get_defined_property_type",
|
||||||
|
"is_uuid",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for the agent JSON structure passed through
|
||||||
|
# the validation and fixing pipeline.
|
||||||
|
AgentDict = dict[str, Any]
|
||||||
|
|
||||||
|
# Shared base pattern (unanchored, lowercase hex); used for both full-string
|
||||||
|
# validation (UUID_REGEX) and text extraction (core._UUID_PATTERN).
|
||||||
|
UUID_RE_STR = r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[a-f0-9]{4}-[a-f0-9]{12}"
|
||||||
|
|
||||||
|
UUID_REGEX = re.compile(r"^" + UUID_RE_STR + r"$")
|
||||||
|
|
||||||
|
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||||
|
MCP_TOOL_BLOCK_ID = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||||
|
AGENT_INPUT_BLOCK_ID = "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||||
|
AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||||
|
|
||||||
|
|
||||||
|
def is_uuid(value: str) -> bool:
|
||||||
|
"""Check if a string is a valid UUID."""
|
||||||
|
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uuid() -> str:
|
||||||
|
"""Generate a new UUID string."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def get_defined_property_type(schema: dict[str, Any], name: str) -> str | None:
|
||||||
|
"""Get property type from a schema, handling nested `_#_` notation."""
|
||||||
|
if "_#_" in name:
|
||||||
|
parent, child = name.split("_#_", 1)
|
||||||
|
parent_schema = schema.get(parent, {})
|
||||||
|
if "properties" in parent_schema and isinstance(
|
||||||
|
parent_schema["properties"], dict
|
||||||
|
):
|
||||||
|
return parent_schema["properties"].get(child, {}).get("type")
|
||||||
|
return None
|
||||||
|
return schema.get(name, {}).get("type")
|
||||||
|
|
||||||
|
|
||||||
|
def are_types_compatible(src: str, sink: str) -> bool:
|
||||||
|
"""Check if two schema types are compatible."""
|
||||||
|
if {src, sink} <= {"integer", "number"}:
|
||||||
|
return True
|
||||||
|
return src == sink
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
"""Shared fix → validate → preview/save pipeline for agent tools."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from backend.copilot.tools.models import (
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .blocks import get_blocks_as_dicts
|
||||||
|
from .core import get_library_agents_by_ids, save_agent_to_library
|
||||||
|
from .fixer import AgentFixer
|
||||||
|
from .validator import AgentValidator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_AGENT_JSON_SIZE = 1_000_000 # 1 MB
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_library_agents(
|
||||||
|
user_id: str | None,
|
||||||
|
library_agent_ids: list[str],
|
||||||
|
) -> list[dict[str, Any]] | None:
|
||||||
|
"""Fetch library agents by IDs for AgentExecutorBlock validation.
|
||||||
|
|
||||||
|
Returns None if no IDs provided or user is not authenticated.
|
||||||
|
"""
|
||||||
|
if not user_id or not library_agent_ids:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
agents = await get_library_agents_by_ids(
|
||||||
|
user_id=user_id,
|
||||||
|
agent_ids=library_agent_ids,
|
||||||
|
)
|
||||||
|
return cast(list[dict[str, Any]], agents)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents by IDs: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def fix_validate_and_save(
|
||||||
|
agent_json: dict[str, Any],
|
||||||
|
*,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
save: bool = True,
|
||||||
|
is_update: bool = False,
|
||||||
|
default_name: str = "Agent",
|
||||||
|
preview_message: str | None = None,
|
||||||
|
save_message: str | None = None,
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
folder_id: str | None = None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Shared pipeline: auto-fix → validate → preview or save.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_json: The agent JSON dict (must already have id/version/is_active set).
|
||||||
|
user_id: The authenticated user's ID.
|
||||||
|
session_id: The chat session ID.
|
||||||
|
save: Whether to save or just preview.
|
||||||
|
is_update: Whether this is an update to an existing agent.
|
||||||
|
default_name: Fallback name if agent_json has none.
|
||||||
|
preview_message: Custom preview message (optional).
|
||||||
|
save_message: Custom save success message (optional).
|
||||||
|
library_agents: Library agents for AgentExecutorBlock validation/fixing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An appropriate ToolResponseBase subclass.
|
||||||
|
"""
|
||||||
|
# Size guard
|
||||||
|
json_size = len(json.dumps(agent_json))
|
||||||
|
if json_size > MAX_AGENT_JSON_SIZE:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Agent JSON is too large ({json_size:,} bytes, "
|
||||||
|
f"max {MAX_AGENT_JSON_SIZE:,}). Reduce the number of nodes."
|
||||||
|
),
|
||||||
|
error="agent_json_too_large",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks = get_blocks_as_dicts()
|
||||||
|
|
||||||
|
# Auto-fix
|
||||||
|
try:
|
||||||
|
fixer = AgentFixer()
|
||||||
|
agent_json = fixer.apply_all_fixes(agent_json, blocks, library_agents)
|
||||||
|
fixes = fixer.get_fixes_applied()
|
||||||
|
if fixes:
|
||||||
|
logger.info(f"Applied {len(fixes)} auto-fixes to agent JSON")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Auto-fix failed: {e}")
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
try:
|
||||||
|
validator = AgentValidator()
|
||||||
|
is_valid, _ = validator.validate(agent_json, blocks, library_agents)
|
||||||
|
if not is_valid:
|
||||||
|
errors = validator.errors
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"The agent has {len(errors)} validation error(s):\n"
|
||||||
|
+ "\n".join(f"- {e}" for e in errors[:5])
|
||||||
|
),
|
||||||
|
error="validation_failed",
|
||||||
|
details={"errors": errors},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Validation failed with exception: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to validate the agent. Please try again.",
|
||||||
|
error="validation_exception",
|
||||||
|
details={"exception": str(e)},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_name = agent_json.get("name", default_name)
|
||||||
|
agent_description = agent_json.get("description", "")
|
||||||
|
node_count = len(agent_json.get("nodes", []))
|
||||||
|
link_count = len(agent_json.get("links", []))
|
||||||
|
|
||||||
|
# Build a warning suffix when name/description is missing or generic
|
||||||
|
_GENERIC_NAMES = {
|
||||||
|
"agent",
|
||||||
|
"generated agent",
|
||||||
|
"customized agent",
|
||||||
|
"updated agent",
|
||||||
|
"new agent",
|
||||||
|
"my agent",
|
||||||
|
}
|
||||||
|
metadata_warnings: list[str] = []
|
||||||
|
if not agent_json.get("name") or agent_name.lower().strip() in _GENERIC_NAMES:
|
||||||
|
metadata_warnings.append("'name'")
|
||||||
|
if not agent_description:
|
||||||
|
metadata_warnings.append("'description'")
|
||||||
|
metadata_hint = ""
|
||||||
|
if metadata_warnings:
|
||||||
|
missing = " and ".join(metadata_warnings)
|
||||||
|
metadata_hint = (
|
||||||
|
f" Note: the agent is missing a meaningful {missing}. "
|
||||||
|
f"Please update the agent_json to include them."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not save:
|
||||||
|
return AgentPreviewResponse(
|
||||||
|
message=(
|
||||||
|
(
|
||||||
|
preview_message
|
||||||
|
or f"Agent '{agent_name}' with {node_count} blocks is ready."
|
||||||
|
)
|
||||||
|
+ metadata_hint
|
||||||
|
),
|
||||||
|
agent_json=agent_json,
|
||||||
|
agent_name=agent_name,
|
||||||
|
description=agent_description,
|
||||||
|
node_count=node_count,
|
||||||
|
link_count=link_count,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="You must be logged in to save agents.",
|
||||||
|
error="auth_required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
created_graph, library_agent = await save_agent_to_library(
|
||||||
|
agent_json, user_id, is_update=is_update, folder_id=folder_id
|
||||||
|
)
|
||||||
|
return AgentSavedResponse(
|
||||||
|
message=(
|
||||||
|
(save_message or f"Agent '{created_graph.name}' has been saved!")
|
||||||
|
+ metadata_hint
|
||||||
|
),
|
||||||
|
agent_id=created_graph.id,
|
||||||
|
agent_name=created_graph.name,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save agent: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to save the agent: {str(e)}",
|
||||||
|
error="save_failed",
|
||||||
|
details={"exception": str(e)},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user