Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-04-08 03:00:28 -04:00).
`.claude/skills/pr-test/SKILL.md` (new file, +534 lines):
---
name: pr-test
description: "E2E manual testing of PRs/branches using docker compose, agent-browser, and API calls. TRIGGER when user asks to manually test a PR, test a feature end-to-end, or run integration tests against a running system."
user-invocable: true
argument-hint: "[worktree path or PR number] — tests the PR in the given worktree. Optional flags: --fix (auto-fix issues found)"
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# Manual E2E Test

Test a PR/branch end-to-end by building the full platform, interacting via browser and API, capturing screenshots, and reporting results.

## Arguments

- `$ARGUMENTS` — worktree path (e.g. `$REPO_ROOT`) or PR number
- If the `--fix` flag is present, auto-fix bugs found and push fixes (like the pr-address loop)

## Step 0: Resolve the target

```bash
# If argument is a PR number, find its worktree
gh pr view {N} --json headRefName --jq '.headRefName'
# If argument is a path, use it directly
```
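A sketch of how the branch name could then be mapped to a worktree path, assuming the branch is already checked out in some worktree (the helper name is illustrative):

```shell
# Illustrative helper: find the worktree that has a given branch checked out.
resolve_worktree_for_branch() {
  git worktree list --porcelain |
    awk -v ref="refs/heads/$1" '
      $1 == "worktree" { w = $2 }                  # remember the latest worktree path
      $1 == "branch" && $2 == ref { print w; exit }'
}

# Usage (assumes gh is authenticated for this repo):
# BRANCH=$(gh pr view "$N" --json headRefName --jq '.headRefName')
# WORKTREE_PATH=$(resolve_worktree_for_branch "$BRANCH")
```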

Determine:
- `REPO_ROOT` — the root repo directory: `git -C "$WORKTREE_PATH" worktree list | head -1 | awk '{print $1}'` (or `git rev-parse --show-toplevel` if not a worktree)
- `WORKTREE_PATH` — the worktree directory
- `PLATFORM_DIR` — `$WORKTREE_PATH/autogpt_platform`
- `BACKEND_DIR` — `$PLATFORM_DIR/backend`
- `FRONTEND_DIR` — `$PLATFORM_DIR/frontend`
- `PR_NUMBER` — the PR number (from `gh pr list --head $(git branch --show-current)`)
- `PR_TITLE` — the PR title, slugified (e.g. "Add copilot permissions" → "add-copilot-permissions")
- `RESULTS_DIR` — `$REPO_ROOT/test-results/PR-{PR_NUMBER}-{slugified-title}`

Create the results directory:
```bash
PR_NUMBER=$(cd "$WORKTREE_PATH" && gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')
PR_TITLE=$(cd "$WORKTREE_PATH" && gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json title --jq '.[0].title' | tr '[:upper:]' '[:lower:]' | sed 's/[^a-z0-9]/-/g' | sed 's/--*/-/g' | sed 's/^-//;s/-$//' | head -c 50)
RESULTS_DIR="$REPO_ROOT/test-results/PR-${PR_NUMBER}-${PR_TITLE}"
mkdir -p "$RESULTS_DIR"
```

**Test user credentials** (for logging into the UI or verifying results manually):
- Email: `test@test.com`
- Password: `testtest123`

## Step 1: Understand the PR

Before testing, understand what changed:

```bash
cd $WORKTREE_PATH
git log --oneline dev..HEAD | head -20
git diff dev --stat
```

Read the changed files to understand:
1. What feature/fix does this PR implement?
2. What components are affected? (backend, frontend, copilot, executor, etc.)
3. What are the key user-facing behaviors to test?

## Step 2: Write test scenarios

Based on the PR analysis, write a test plan to `$RESULTS_DIR/test-plan.md`:

```markdown
# Test Plan: PR #{N} — {title}

## Scenarios
1. [Scenario name] — [what to verify]
2. ...

## API Tests (if applicable)
1. [Endpoint] — [expected behavior]

## UI Tests (if applicable)
1. [Page/component] — [interaction to test]

## Negative Tests
1. [What should NOT happen]
```

**Be critical** — include edge cases, error paths, and security checks.

## Step 3: Environment setup

### 3a. Copy .env files from the root worktree

The root worktree (`$REPO_ROOT`) has the canonical `.env` files with all API keys. Copy them to the target worktree:

```bash
# CRITICAL: .env files are NOT checked into git. They must be copied manually.
cp $REPO_ROOT/autogpt_platform/.env $PLATFORM_DIR/.env
cp $REPO_ROOT/autogpt_platform/backend/.env $BACKEND_DIR/.env
cp $REPO_ROOT/autogpt_platform/frontend/.env $FRONTEND_DIR/.env
```
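After the copy, it can be worth failing fast on missing keys. A minimal sketch; the helper name is an assumption, and `OPEN_ROUTER_API_KEY` is the one key this skill relies on later (step 3b, Option 2):

```shell
# Hypothetical helper: verify that a .env file defines each required key.
require_env_keys() {
  file="$1"; shift
  missing=0
  for key in "$@"; do
    grep -q "^${key}=" "$file" || { echo "MISSING: $key in $file"; missing=1; }
  done
  return "$missing"
}

# e.g. require_env_keys "$BACKEND_DIR/.env" OPEN_ROUTER_API_KEY
```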

### 3b. Configure copilot authentication

The copilot needs an LLM API to function. Two approaches (try subscription first):

#### Option 1: Subscription mode (preferred — uses your Claude Max/Pro subscription)

The `claude_agent_sdk` Python package **bundles its own Claude CLI binary** — no need to install `@anthropic-ai/claude-code` via npm. The backend auto-provisions credentials from environment variables on startup.

Run the helper script to extract tokens from your host and auto-update `backend/.env` (works on macOS, Linux, and Windows/WSL):

```bash
# Extracts OAuth tokens and writes CLAUDE_CODE_OAUTH_TOKEN + CLAUDE_CODE_REFRESH_TOKEN into .env
bash $BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env
```

**How it works:** The script reads the OAuth token from:
- **macOS**: system keychain (`"Claude Code-credentials"`)
- **Linux/WSL**: `~/.claude/.credentials.json`
- **Windows**: `%APPDATA%/claude/.credentials.json`

It sets `CLAUDE_CODE_OAUTH_TOKEN`, `CLAUDE_CODE_REFRESH_TOKEN`, and `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` in the `.env` file. On container startup, the backend auto-provisions `~/.claude/.credentials.json` inside the container from these env vars. The SDK's bundled CLI then authenticates using that file. No `claude login`, no npm install needed.

**Note:** The OAuth token expires (~24h). If copilot returns auth errors, re-run the script and restart: `$BACKEND_DIR/scripts/refresh_claude_token.sh --env-file $BACKEND_DIR/.env && docker compose up -d copilot_executor`

#### Option 2: OpenRouter API key mode (fallback)

If subscription mode doesn't work, switch to API key mode using OpenRouter:

```bash
# In $BACKEND_DIR/.env, ensure these are set:
CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false
CHAT_API_KEY=<value of OPEN_ROUTER_API_KEY from the same .env>
CHAT_BASE_URL=https://openrouter.ai/api/v1
CHAT_USE_CLAUDE_AGENT_SDK=true
```

Use `perl -i` to update these values in place (note `cut -d= -f2-`, so keys whose value contains `=` are not truncated):
```bash
ORKEY=$(grep "^OPEN_ROUTER_API_KEY=" $BACKEND_DIR/.env | cut -d= -f2-)
[ -n "$ORKEY" ] || { echo "ERROR: OPEN_ROUTER_API_KEY is missing in $BACKEND_DIR/.env"; exit 1; }
perl -i -pe 's/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true/CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=false/' $BACKEND_DIR/.env
# Add or update CHAT_API_KEY and CHAT_BASE_URL
grep -q "^CHAT_API_KEY=" $BACKEND_DIR/.env && perl -i -pe "s|^CHAT_API_KEY=.*|CHAT_API_KEY=$ORKEY|" $BACKEND_DIR/.env || echo "CHAT_API_KEY=$ORKEY" >> $BACKEND_DIR/.env
grep -q "^CHAT_BASE_URL=" $BACKEND_DIR/.env && perl -i -pe 's|^CHAT_BASE_URL=.*|CHAT_BASE_URL=https://openrouter.ai/api/v1|' $BACKEND_DIR/.env || echo "CHAT_BASE_URL=https://openrouter.ai/api/v1" >> $BACKEND_DIR/.env
```

### 3c. Stop conflicting containers

```bash
# Stop any running app containers (keep infra: supabase, redis, rabbitmq, clamav)
docker ps --format "{{.Names}}" | grep -E "rest_server|executor|copilot|websocket|database_manager|scheduler|notification|frontend|migrate" | while read name; do
  docker stop "$name" 2>/dev/null
done
```

### 3e. Build and start

```bash
cd $PLATFORM_DIR && docker compose build --no-cache 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker build failed"; exit 1; fi

cd $PLATFORM_DIR && docker compose up -d 2>&1 | tail -20
if [ ${PIPESTATUS[0]} -ne 0 ]; then echo "ERROR: Docker compose up failed"; exit 1; fi
```

**Note:** If the container appears to be running old code (e.g. missing PR changes), use `docker compose build --no-cache` to force a full rebuild. Docker BuildKit may sometimes reuse cached `COPY` layers from a previous build on a different branch.

**Expected time: 3-8 minutes** for build, 5-10 minutes with `--no-cache`.

### 3f. Wait for services to be ready

```bash
# Poll until backend and frontend respond (up to ~5 minutes)
for i in $(seq 1 60); do
  BACKEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8006/docs 2>/dev/null)
  FRONTEND=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:3000 2>/dev/null)
  if [ "$BACKEND" = "200" ] && [ "$FRONTEND" = "200" ]; then
    echo "Services ready"
    break
  fi
  sleep 5
done
if [ "$BACKEND" != "200" ] || [ "$FRONTEND" != "200" ]; then
  echo "ERROR: services not ready (backend=$BACKEND frontend=$FRONTEND)"; exit 1
fi
```

### 3h. Create test user and get auth token

```bash
ANON_KEY=$(grep "NEXT_PUBLIC_SUPABASE_ANON_KEY=" $FRONTEND_DIR/.env | sed 's/.*NEXT_PUBLIC_SUPABASE_ANON_KEY=//' | tr -d '[:space:]')

# Signup (idempotent — returns "User already registered" if exists)
RESULT=$(curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
  -H "apikey: $ANON_KEY" \
  -H 'Content-Type: application/json' \
  -d '{"email":"test@test.com","password":"testtest123"}')

# If "Database error finding user", restart supabase-auth and retry
if echo "$RESULT" | grep -q "Database error"; then
  docker restart supabase-auth && sleep 5
  curl -s -X POST 'http://localhost:8000/auth/v1/signup' \
    -H "apikey: $ANON_KEY" \
    -H 'Content-Type: application/json' \
    -d '{"email":"test@test.com","password":"testtest123"}'
fi

# Get auth token
TOKEN=$(curl -s -X POST 'http://localhost:8000/auth/v1/token?grant_type=password' \
  -H "apikey: $ANON_KEY" \
  -H 'Content-Type: application/json' \
  -d '{"email":"test@test.com","password":"testtest123"}' | jq -r '.access_token // ""')
[ -n "$TOKEN" ] || { echo "ERROR: failed to obtain auth token"; exit 1; }
```

**Use this token for ALL API calls:**
```bash
curl -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/...
```
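Before a long test pass, the token can be sanity-checked against Supabase GoTrue's `GET /auth/v1/user` endpoint. A sketch; the exact response shape can vary by Supabase version:

```shell
# A valid token should resolve to the test user's record, including its email.
EMAIL=$(curl -s --max-time 5 -H "apikey: $ANON_KEY" -H "Authorization: Bearer $TOKEN" \
  http://localhost:8000/auth/v1/user | jq -r '.email // "INVALID"')
echo "token resolves to: $EMAIL"   # expect test@test.com
```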

## Step 4: Run tests

### Service ports reference

| Service | Port | URL |
|---------|------|-----|
| Frontend | 3000 | http://localhost:3000 |
| Backend REST | 8006 | http://localhost:8006 |
| Supabase Auth (via Kong) | 8000 | http://localhost:8000 |
| Executor | 8002 | http://localhost:8002 |
| Copilot Executor | 8008 | http://localhost:8008 |
| WebSocket | 8001 | http://localhost:8001 |
| Database Manager | 8005 | http://localhost:8005 |
| Redis | 6379 | localhost:6379 |
| RabbitMQ | 5672 | localhost:5672 |
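A quick reachability pass over the HTTP ports above can confirm the stack is up before deeper testing. A sketch; the non-HTTP services (Redis, RabbitMQ) are skipped:

```shell
# Prints the HTTP status for localhost:$1, or 000 if the port is unreachable.
check_port() {
  curl -s -o /dev/null -w '%{http_code}' --max-time 2 "http://localhost:$1" || true
}

for port in 3000 8000 8001 8002 8005 8006 8008; do
  echo "port $port -> HTTP $(check_port "$port")"
done
```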

### API testing

Use `curl` with the auth token for backend API tests:

```bash
# Example: List agents
curl -s -H "Authorization: Bearer $TOKEN" http://localhost:8006/api/graphs | jq . | head -20

# Example: Create an agent
curl -s -X POST http://localhost:8006/api/graphs \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{...}' | jq .

# Example: Run an agent
curl -s -X POST "http://localhost:8006/api/graphs/{graph_id}/execute" \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{"data": {...}}'

# Example: Get execution results
curl -s -H "Authorization: Bearer $TOKEN" \
  "http://localhost:8006/api/graphs/{graph_id}/executions/{exec_id}" | jq .
```

### Browser testing with agent-browser

```bash
# Close any existing session
agent-browser close 2>/dev/null || true

# Use --session-name to persist cookies across navigations
# This means login only needs to happen once per test session
agent-browser --session-name pr-test open 'http://localhost:3000/login' --timeout 15000

# Get interactive elements
agent-browser --session-name pr-test snapshot | grep "textbox\|button"

# Login
agent-browser --session-name pr-test fill {email_ref} "test@test.com"
agent-browser --session-name pr-test fill {password_ref} "testtest123"
agent-browser --session-name pr-test click {login_button_ref}
sleep 5

# Dismiss cookie banner if present
agent-browser --session-name pr-test click 'text=Accept All' 2>/dev/null || true

# Navigate — cookies are preserved so login persists
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000

# Take screenshot
agent-browser --session-name pr-test screenshot $RESULTS_DIR/01-page.png

# Interact with elements
agent-browser --session-name pr-test fill {ref} "text"
agent-browser --session-name pr-test press "Enter"
agent-browser --session-name pr-test click {ref}
agent-browser --session-name pr-test click 'text=Button Text'

# Read page content
agent-browser --session-name pr-test snapshot | grep "text:"
```

**Key pages:**
- `/copilot` — CoPilot chat (for testing copilot features)
- `/build` — Agent builder (for testing block/node features)
- `/build?flowID={id}` — Specific agent in builder
- `/library` — Agent library (for testing listing/import features)
- `/library/agents/{id}` — Agent detail with run history
- `/marketplace` — Marketplace

### Checking logs

```bash
# Backend REST server
docker logs autogpt_platform-rest_server-1 2>&1 | tail -30

# Executor (runs agent graphs)
docker logs autogpt_platform-executor-1 2>&1 | tail -30

# Copilot executor (runs copilot chat sessions)
docker logs autogpt_platform-copilot_executor-1 2>&1 | tail -30

# Frontend
docker logs autogpt_platform-frontend-1 2>&1 | tail -30

# Filter for errors
docker logs autogpt_platform-executor-1 2>&1 | grep -i "error\|exception\|traceback" | tail -20
```

### Copilot chat testing

The copilot uses SSE streaming. To test via API:

```bash
# Create a session
SESSION_ID=$(curl -s -X POST 'http://localhost:8006/api/chat/sessions' \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{}' | jq -r '.id // .session_id // ""')

# Stream a message (SSE - will stream chunks)
curl -N -X POST "http://localhost:8006/api/chat/sessions/$SESSION_ID/stream" \
  -H "Authorization: Bearer $TOKEN" \
  -H 'Content-Type: application/json' \
  -d '{"message": "Hello, what can you help me with?"}' \
  --max-time 60 2>/dev/null | head -50
```

Or test via browser (preferred for UI verification):
```bash
agent-browser --session-name pr-test open 'http://localhost:3000/copilot' --timeout 10000
# ... fill chat input and press Enter, wait 20-30s for response
```
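When asserting on the SSE stream programmatically, counting `data:` events is a cheap smoke check. A sketch; the exact event framing depends on the backend:

```shell
# Count SSE data events in a captured stream; zero events usually means the
# copilot produced no output (often an auth or model configuration problem).
count_sse_events() { grep -c '^data:'; }

# e.g.: curl -N ... | tee "$RESULTS_DIR/copilot-stream.txt" | count_sse_events
```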

## Step 5: Record results

For each test scenario, record in `$RESULTS_DIR/test-report.md`:

```markdown
# E2E Test Report: PR #{N} — {title}
Date: {date}
Branch: {branch}
Worktree: {path}

## Environment
- Docker services: [list running containers]
- API keys: OpenRouter={present/missing}, E2B={present/missing}

## Test Results

### Scenario 1: {name}
**Steps:**
1. ...
2. ...
**Expected:** ...
**Actual:** ...
**Result:** PASS / FAIL
**Screenshot:** {NN}-{description}.png
**Logs:** (if relevant)

### Scenario 2: {name}
...

## Summary
- Total: X scenarios
- Passed: Y
- Failed: Z
- Bugs found: [list]
```

Take screenshots at each significant step:
```bash
agent-browser --session-name pr-test screenshot $RESULTS_DIR/{NN}-{description}.png
```

## Step 6: Report results

After all tests complete, output a summary to the user:

1. Table of all scenarios with PASS/FAIL
2. Screenshots of failures (read the PNG files to show them)
3. Any bugs found with details
4. Recommendations

### Post test results as PR comment with screenshots

Upload screenshots to the PR using the GitHub Git API (no local git operations — safe for worktrees).

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
SCREENSHOTS_BRANCH="test-screenshots/pr-${PR_NUMBER}"
SCREENSHOTS_DIR="test-screenshots/PR-${PR_NUMBER}"

# Step 1: Create a blob for each screenshot and collect the tree entries as JSON
# (each blob is uploaded exactly once)
TREE_JSON='['
FIRST=true
for img in $RESULTS_DIR/*.png; do
  BASENAME=$(basename "$img")
  B64=$(base64 < "$img")
  BLOB_SHA=$(gh api "repos/${REPO}/git/blobs" -f content="$B64" -f encoding="base64" --jq '.sha')
  if [ "$FIRST" = true ]; then FIRST=false; else TREE_JSON+=','; fi
  TREE_JSON+="{\"path\":\"${SCREENSHOTS_DIR}/${BASENAME}\",\"mode\":\"100644\",\"type\":\"blob\",\"sha\":\"${BLOB_SHA}\"}"
done
TREE_JSON+=']'

# Step 2: Create a tree with all screenshot blobs
# (gh api's -f flags don't handle arrays of objects, so pipe the JSON body via --input)
TREE_SHA=$(echo "$TREE_JSON" | jq -c '{tree: .}' | gh api "repos/${REPO}/git/trees" --input - --jq '.sha')

# Step 3: Create a commit pointing to that tree
COMMIT_SHA=$(gh api "repos/${REPO}/git/commits" \
  -f message="test: add E2E test screenshots for PR #${PR_NUMBER}" \
  -f tree="$TREE_SHA" \
  --jq '.sha')

# Step 4: Create or update the ref (branch) — no local checkout needed
# (-F sends force as a boolean, not the string "true")
gh api "repos/${REPO}/git/refs" \
  -f ref="refs/heads/${SCREENSHOTS_BRANCH}" \
  -f sha="$COMMIT_SHA" 2>/dev/null \
  || gh api "repos/${REPO}/git/refs/heads/${SCREENSHOTS_BRANCH}" \
    -X PATCH -f sha="$COMMIT_SHA" -F force=true

# Step 5: Build image markdown and post the comment
REPO_URL="https://raw.githubusercontent.com/${REPO}/${SCREENSHOTS_BRANCH}"
IMAGE_MARKDOWN=""
for img in $RESULTS_DIR/*.png; do
  BASENAME=$(basename "$img")
  IMAGE_MARKDOWN="$IMAGE_MARKDOWN
![${BASENAME}](${REPO_URL}/${SCREENSHOTS_DIR}/${BASENAME})
"
done

gh api "repos/${REPO}/issues/$PR_NUMBER/comments" -f body="$(cat <<EOF
## 🧪 E2E Test Report

$(cat $RESULTS_DIR/test-report.md)

### Screenshots
${IMAGE_MARKDOWN}
EOF
)"
```

This approach uses the GitHub Git API to create blobs, trees, commits, and refs entirely server-side. No local `git checkout` or `git push` — safe for worktrees and won't interfere with the PR branch.

## Fix mode (--fix flag)

When `--fix` is present, after finding a bug:

1. Identify the root cause in the code
2. Fix it in the worktree
3. Rebuild the affected service: `cd $PLATFORM_DIR && docker compose up --build -d {service_name}`
4. Re-test the scenario
5. If fix works, commit and push:
   ```bash
   cd $WORKTREE_PATH
   git add -A
   git commit -m "fix: {description of fix}"
   git push
   ```
6. Continue testing remaining scenarios
7. After all fixes, run the full test suite again to ensure no regressions

### Fix loop (like pr-address)

```text
test scenario → find bug → fix code → rebuild service → re-test
→ repeat until all scenarios pass
→ commit + push all fixes
→ run full re-test to verify
```
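The loop above can be sketched as a bounded retry. `run_all_scenarios` and `fix_and_rebuild` are hypothetical stand-ins for the numbered steps, and `MAX_ROUNDS` is an assumption:

```shell
# Hedged sketch of the fix loop: retry until scenarios pass or rounds run out.
fix_loop() {
  max_rounds="${1:-5}"
  round=1
  while [ "$round" -le "$max_rounds" ]; do
    if run_all_scenarios; then
      echo "all scenarios pass after $round round(s)"
      return 0
    fi
    fix_and_rebuild            # fix code, then rebuild the affected service
    round=$((round + 1))
  done
  echo "still failing after $max_rounds rounds" >&2
  return 1
}
```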

## Known issues and workarounds

### Problem: "Database error finding user" on signup
**Cause:** Supabase auth service schema cache is stale after migration.
**Fix:** `docker restart supabase-auth && sleep 5` then retry signup.

### Problem: Copilot returns auth errors in subscription mode
**Cause:** `CHAT_USE_CLAUDE_CODE_SUBSCRIPTION=true` but `CLAUDE_CODE_OAUTH_TOKEN` is not set or expired.
**Fix:** Re-run the token refresh script (see step 3b, Option 1) and recreate the container (`docker compose up -d copilot_executor`). The backend auto-provisions `~/.claude/.credentials.json` from the env var on startup. No `npm install` or `claude login` needed — the SDK bundles its own CLI binary.

### Problem: agent-browser can't find chromium
**Cause:** The Dockerfile auto-provisions system chromium on all architectures (including ARM64). If your branch is behind `dev`, this may not be present yet.
**Fix:** Check if chromium exists: `which chromium || which chromium-browser`. If missing, install it: `apt-get install -y chromium` and set `AGENT_BROWSER_EXECUTABLE_PATH=/usr/bin/chromium` in the container environment.

### Problem: agent-browser selector matches multiple elements
**Cause:** `text=X` matches all elements containing that text.
**Fix:** Use `agent-browser snapshot` to get specific `ref=eNN` references, then use those: `agent-browser click eNN`.

### Problem: Frontend shows cookie banner blocking interaction
**Fix:** `agent-browser click 'text=Accept All'` before other interactions.

### Problem: Container loses npm packages after rebuild
**Cause:** `docker compose up --build` rebuilds the image, losing runtime installs.
**Fix:** Add packages to the Dockerfile instead of installing at runtime.

### Problem: Services not starting after `docker compose up`
**Fix:** Wait and check health: `docker compose ps`. Common cause: migration hasn't finished. Check: `docker logs autogpt_platform-migrate-1 2>&1 | tail -5`. If supabase-db isn't healthy: `docker restart supabase-db && sleep 10`.

### Problem: Docker uses cached layers with old code (PR changes not visible)
**Cause:** `docker compose up --build` reuses cached `COPY` layers from previous builds. If the PR branch changes Python files but the previous build already cached that layer from `dev`, the container runs `dev` code.
**Fix:** Always use `docker compose build --no-cache` for the first build of a PR branch. Subsequent rebuilds within the same branch can use `--build`.

### Problem: `agent-browser open` loses login session
**Cause:** Without session persistence, `agent-browser open` starts fresh.
**Fix:** Use `--session-name pr-test` on ALL agent-browser commands. This auto-saves/restores cookies and localStorage across navigations. Alternatively, use `agent-browser eval "window.location.href = '...'"` to navigate within the same context.

### Problem: Supabase auth returns "Database error querying schema"
**Cause:** The database schema changed (migration ran) but supabase-auth has a stale schema cache.
**Fix:** `docker restart supabase-db && sleep 10 && docker restart supabase-auth && sleep 8`. If user data was lost, re-signup.

```diff
@@ -592,6 +592,11 @@ async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
 async def configure_user_auto_top_up(
     request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
 ) -> str:
+    """Configure auto top-up settings and perform an immediate top-up if needed.
+
+    Raises HTTPException(422) if the request parameters are invalid or if
+    the credit top-up fails.
+    """
     if request.threshold < 0:
         raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
     if request.amount < 500 and request.amount != 0:
```

```diff
@@ -606,10 +611,20 @@ async def configure_user_auto_top_up(
     user_credit_model = await get_user_credit_model(user_id)
     current_balance = await user_credit_model.get_credits(user_id)

-    if current_balance < request.threshold:
-        await user_credit_model.top_up_credits(user_id, request.amount)
-    else:
-        await user_credit_model.top_up_credits(user_id, 0)
+    try:
+        if current_balance < request.threshold:
+            await user_credit_model.top_up_credits(user_id, request.amount)
+        else:
+            await user_credit_model.top_up_credits(user_id, 0)
+    except ValueError as e:
+        known_messages = (
+            "must not be negative",
+            "already exists for user",
+            "No payment method found",
+        )
+        if any(msg in str(e) for msg in known_messages):
+            raise HTTPException(status_code=422, detail=str(e))
+        raise

     await set_auto_top_up(
         user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
```

```diff
@@ -188,6 +188,7 @@ async def upload_file(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
     file: UploadFile,
     session_id: str | None = Query(default=None),
+    overwrite: bool = Query(default=False),
 ) -> UploadFileResponse:
     """
     Upload a file to the user's workspace.
```

```diff
@@ -248,7 +249,9 @@ async def upload_file(
     # Write file via WorkspaceManager
     manager = WorkspaceManager(user_id, workspace.id, session_id)
     try:
-        workspace_file = await manager.write_file(content, filename)
+        workspace_file = await manager.write_file(
+            content, filename, overwrite=overwrite
+        )
     except ValueError as e:
         raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
```

```diff
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import logging
 import platform
```

```diff
@@ -37,8 +38,10 @@ import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
+import backend.data.llm_registry
 import backend.data.user
 import backend.integrations.webhooks.utils
+import backend.server.v2.llm
 import backend.util.service
 import backend.util.settings
 from backend.api.features.library.exceptions import (
```

```diff
@@ -117,16 +120,56 @@ async def lifespan_context(app: fastapi.FastAPI):

     AutoRegistry.patch_integrations()

+    # Load LLM registry before initializing blocks so blocks can use registry data.
+    # Tries Redis first (fast path on warm restart), falls back to DB.
+    # Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
+    try:
+        await backend.data.llm_registry.refresh_llm_registry()
+        logger.info("LLM registry loaded successfully at startup")
+    except Exception as e:
+        logger.warning(
+            f"Failed to load LLM registry at startup: {e}. "
+            "Blocks will initialize with empty registry."
+        )
+
+    # Start background task so this worker reloads its in-process cache whenever
+    # another worker (e.g. the admin API) refreshes the registry.
+    _registry_subscription_task = asyncio.create_task(
+        backend.data.llm_registry.subscribe_to_registry_refresh(
+            backend.data.llm_registry.refresh_llm_registry
+        )
+    )
+
     await backend.data.block.initialize_blocks()

     await backend.data.user.migrate_and_encrypt_user_integrations()
     await backend.data.graph.fix_llm_provider_credentials()
-    await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
+    try:
+        await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
+    except Exception as e:
+        err_str = str(e)
+        if "AgentNode" in err_str or "does not exist" in err_str:
+            logger.warning(
+                f"migrate_llm_models skipped: AgentNode table not found ({e}). "
+                "This is expected in test environments."
+            )
+        else:
+            logger.error(
+                f"migrate_llm_models failed unexpectedly: {e}",
+                exc_info=True,
+            )
+
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

     with launch_darkly_context():
         yield

+    _registry_subscription_task.cancel()
+    try:
+        await _registry_subscription_task
+    except asyncio.CancelledError:
+        pass
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:
```

```diff
@@ -210,13 +253,22 @@ instrument_fastapi(
 def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
     def handler(request: fastapi.Request, exc: Exception):
         if log_error:
-            logger.exception(
-                "%s %s failed. Investigate and resolve the underlying issue: %s",
-                request.method,
-                request.url.path,
-                exc,
-                exc_info=exc,
-            )
+            if status_code >= 500:
+                logger.exception(
+                    "%s %s failed. Investigate and resolve the underlying issue: %s",
+                    request.method,
+                    request.url.path,
+                    exc,
+                    exc_info=exc,
+                )
+            else:
+                logger.warning(
+                    "%s %s failed with %d: %s",
+                    request.method,
+                    request.url.path,
+                    status_code,
+                    exc,
+                )

         hint = (
             "Adjust the request and retry."
```

```diff
@@ -266,12 +318,10 @@ async def validation_error_handler(


 app.add_exception_handler(PrismaError, handle_internal_http_error(500))
-app.add_exception_handler(
-    FolderAlreadyExistsError, handle_internal_http_error(409, False)
-)
-app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
-app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
-app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
+app.add_exception_handler(FolderAlreadyExistsError, handle_internal_http_error(409))
+app.add_exception_handler(FolderValidationError, handle_internal_http_error(400))
+app.add_exception_handler(NotFoundError, handle_internal_http_error(404))
+app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403))
 app.add_exception_handler(RequestValidationError, validation_error_handler)
 app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
 app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
```
|
||||
@@ -348,6 +398,16 @@ app.include_router(
    tags=["oauth"],
    prefix="/api/oauth",
)
app.include_router(
    backend.server.v2.llm.router,
    tags=["v2", "llm"],
    prefix="/api",
)
app.include_router(
    backend.server.v2.llm.admin_router,
    tags=["v2", "llm", "admin"],
    prefix="/api",
)

app.mount("/external-api", external_api)
@@ -796,6 +796,19 @@ async def llm_call(
    )
    prompt = result.messages

    # Sanitize unpaired surrogates in message content to prevent
    # UnicodeEncodeError when httpx encodes the JSON request body.
    for msg in prompt:
        content = msg.get("content")
        if isinstance(content, str):
            try:
                content.encode("utf-8")
            except UnicodeEncodeError:
                logger.warning("Sanitized unpaired surrogates in LLM prompt content")
                msg["content"] = content.encode("utf-8", errors="surrogatepass").decode(
                    "utf-8", errors="replace"
                )

    # Calculate available tokens based on context window and input length
    estimated_input_tokens = estimate_token_count(prompt)
    model_max_output = llm_model.max_output_tokens or int(2**15)
@@ -934,7 +934,7 @@ class SmartDecisionMakerBlock(Block):
                )

        except Exception as e:
            logger.error(f"Tool execution with manager failed: {e}")
            logger.warning(f"Tool execution with manager failed: {e}")
            # Return error response
            return _create_tool_response(
                tool_call.id,
@@ -12,34 +12,18 @@ from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = f"""\

### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
### Sharing files
After `write_workspace_file`, embed the `download_url` in Markdown:
- File: `[report.csv](workspace://file_id#text/csv)`
- Image: ``
- Video: ``

- **Any file** — shows as a clickable download link:
  `[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
  ``
- **Video** — renders inline in chat with player controls:
  ``

The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.

### Passing file content to tools — @@agptfile: references
Instead of copying large file contents into a tool argument, pass a file
reference and the platform will load the content for you.

Syntax: `@@agptfile:<uri>[<start>-<end>]`

- `<uri>` **must** start with `workspace://` or `/` (absolute path):
  - `workspace://<file_id>` — workspace file by ID
  - `workspace:///<path>` — workspace file by virtual path
  - `/absolute/local/path` — ephemeral or sdk_cwd file
  - E2B sandbox absolute path (e.g. `/home/user/script.py`)
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
- URIs that do not start with `workspace://` or `/` are **not** expanded.
### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
- `/absolute/path` — local/sandbox files
- `[start-end]` — optional 1-indexed line range
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.

Examples:
```
@@ -50,21 +34,9 @@ Examples:
@@agptfile:/home/user/script.py
```

You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only; legacy `.xls` is NOT supported). Unrecognised formats return plain string.

**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.

**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.
**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).

### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
@@ -166,17 +138,12 @@ def _build_storage_supplement(

## Tool notes

### Shell commands
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
  for shell commands — it runs {sandbox_type}.

### Working directory
- Your working directory is: `{working_dir}`
- All SDK file tools AND `bash_exec` operate on the same filesystem
- Use relative paths or absolute paths under `{working_dir}` for all file operations
### Shell & filesystem
- The SDK built-in Bash tool is NOT available. Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
- SDK file tools (Read/Write/Edit/Glob/Grep) and `bash_exec` share one filesystem — use relative or absolute paths under this dir.
- `read_workspace_file`/`write_workspace_file` operate on **persistent cloud workspace storage** (separate from the working dir).

### Two storage systems — CRITICAL to understand

1. **{storage_system_1_name}** (`{working_dir}`):
   {characteristics}
   {persistence}
@@ -2,13 +2,11 @@

import asyncio
import base64
import functools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import uuid
@@ -77,6 +75,7 @@ from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .subscription import validate_subscription as _validate_claude_code_subscription
from .tool_adapter import (
    create_copilot_mcp_server,
    get_copilot_tool_names,
@@ -458,37 +457,6 @@ def _resolve_sdk_model() -> str | None:
    return model


@functools.cache
def _validate_claude_code_subscription() -> None:
    """Validate Claude CLI is installed and responds to `--version`.

    Cached so the blocking subprocess check runs at most once per process
    lifetime. A failure (CLI not installed) is a config error that requires
    a process restart anyway.
    """
    claude_path = shutil.which("claude")
    if not claude_path:
        raise RuntimeError(
            "Claude Code CLI not found. Install it with: "
            "npm install -g @anthropic-ai/claude-code"
        )
    result = subprocess.run(
        [claude_path, "--version"],
        capture_output=True,
        text=True,
        timeout=10,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Claude CLI check failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    logger.info(
        "Claude Code subscription mode: CLI version %s",
        result.stdout.strip(),
    )


def _build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
144
autogpt_platform/backend/backend/copilot/sdk/subscription.py
Normal file
@@ -0,0 +1,144 @@
"""Claude Code subscription auth helpers.

Handles locating the SDK-bundled CLI binary, provisioning credentials from
environment variables, and validating that subscription auth is functional.
"""

import functools
import json
import logging
import os
import shutil
import subprocess

logger = logging.getLogger(__name__)


def find_bundled_cli() -> str:
    """Locate the Claude CLI binary bundled inside ``claude_agent_sdk``.

    Falls back to ``shutil.which("claude")`` if the SDK bundle is absent.
    """
    try:
        from claude_agent_sdk._internal.transport.subprocess_cli import (
            SubprocessCLITransport,
        )

        path = SubprocessCLITransport._find_bundled_cli(None)  # type: ignore[arg-type]
        if path:
            return str(path)
    except Exception:
        pass
    system_path = shutil.which("claude")
    if system_path:
        return system_path
    raise RuntimeError(
        "Claude CLI not found — neither the SDK-bundled binary nor a "
        "system-installed `claude` could be located."
    )
def provision_credentials_file() -> None:
    """Write ``~/.claude/.credentials.json`` from env when running headless.

    If ``CLAUDE_CODE_OAUTH_TOKEN`` is set (an OAuth *access* token obtained
    from ``claude auth status`` or extracted from the macOS keychain), this
    helper writes a minimal credentials file so the bundled CLI can
    authenticate without an interactive ``claude login``.

    A ``CLAUDE_CODE_REFRESH_TOKEN`` env var is optional but recommended —
    it lets the CLI silently refresh an expired access token.
    """
    access_token = os.environ.get("CLAUDE_CODE_OAUTH_TOKEN", "").strip()
    if not access_token:
        return

    creds_dir = os.path.expanduser("~/.claude")
    creds_path = os.path.join(creds_dir, ".credentials.json")

    # Don't overwrite an existing credentials file (e.g. from a volume mount).
    if os.path.exists(creds_path):
        logger.debug("Credentials file already exists at %s — skipping", creds_path)
        return

    os.makedirs(creds_dir, exist_ok=True)

    creds = {
        "claudeAiOauth": {
            "accessToken": access_token,
            "refreshToken": os.environ.get("CLAUDE_CODE_REFRESH_TOKEN", "").strip(),
            "expiresAt": 0,
            "scopes": [
                "user:inference",
                "user:profile",
                "user:sessions:claude_code",
            ],
        }
    }
    with open(creds_path, "w") as f:
        json.dump(creds, f)
    logger.info("Provisioned Claude credentials file at %s", creds_path)
@functools.cache
def validate_subscription() -> None:
    """Validate the bundled Claude CLI is reachable and authenticated.

    Cached so the blocking subprocess check runs at most once per process
    lifetime. On first call, also provisions ``~/.claude/.credentials.json``
    from the ``CLAUDE_CODE_OAUTH_TOKEN`` env var when available.
    """
    provision_credentials_file()

    cli = find_bundled_cli()
    result = subprocess.run(
        [cli, "--version"],
        capture_output=True,
        text=True,
        timeout=10,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Claude CLI check failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    logger.info(
        "Claude Code subscription mode: CLI version %s",
        result.stdout.strip(),
    )

    # Verify the CLI is actually authenticated.
    auth_result = subprocess.run(
        [cli, "auth", "status"],
        capture_output=True,
        text=True,
        timeout=10,
        env={
            **os.environ,
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        },
    )
    if auth_result.returncode != 0:
        raise RuntimeError(
            "Claude CLI is not authenticated. Either:\n"
            " • Set CLAUDE_CODE_OAUTH_TOKEN env var (from `claude auth status` "
            "or macOS keychain), or\n"
            " • Mount ~/.claude/.credentials.json into the container, or\n"
            " • Run `claude login` inside the container."
        )
    try:
        status = json.loads(auth_result.stdout)
        if not status.get("loggedIn"):
            raise RuntimeError(
                "Claude CLI reports loggedIn=false. Set CLAUDE_CODE_OAUTH_TOKEN "
                "or run `claude login`."
            )
        logger.info(
            "Claude subscription auth: method=%s, email=%s",
            status.get("authMethod"),
            status.get("email"),
        )
    except json.JSONDecodeError:
        logger.warning("Could not parse `claude auth status` output")
@@ -22,13 +22,12 @@ class AddUnderstandingTool(BaseTool):

    @property
    def description(self) -> str:
        return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.

Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""
        return (
            "Store user's business context, workflows, pain points, and automation goals. "
            "Call whenever the user shares business info. Each call incrementally merges "
            "with existing data — provide only the fields you have. "
            "Builds a profile that helps recommend better agents for the user's needs."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -410,18 +410,11 @@ class BrowserNavigateTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Navigate to a URL using a real browser. Returns an accessibility "
            "tree snapshot listing the page's interactive elements with @ref IDs "
            "(e.g. @e3) that can be used with browser_act. "
            "Session persists — cookies and login state carry over between calls. "
            "Use this (with browser_act) for multi-step interaction: login flows, "
            "form filling, button clicks, or anything requiring page interaction. "
            "For plain static pages, prefer web_fetch — no browser overhead. "
            "For authenticated pages: navigate to the login page first, use browser_act "
            "to fill credentials and submit, then navigate to the target page. "
            "Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
            "state. If elements seem missing, use browser_act with action='wait' and a "
            "CSS selector or millisecond delay, then take a browser_screenshot to verify."
            "Navigate to a URL in a real browser. Returns accessibility tree with @ref IDs "
            "for browser_act. Session persists (cookies/auth carry over). "
            "For static pages, prefer web_fetch. "
            "For SPAs, elements may load late — use browser_act with wait + browser_screenshot to verify. "
            "For auth: navigate to login, fill creds and submit with browser_act, then navigate to target."
        )

    @property
@@ -431,13 +424,13 @@ class BrowserNavigateTool(BaseTool):
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The HTTP/HTTPS URL to navigate to.",
                    "description": "HTTP/HTTPS URL to navigate to.",
                },
                "wait_for": {
                    "type": "string",
                    "enum": ["networkidle", "load", "domcontentloaded"],
                    "default": "networkidle",
                    "description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
                    "description": "Navigation completion strategy (default: networkidle).",
                },
            },
            "required": ["url"],
@@ -556,14 +549,12 @@ class BrowserActTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Interact with the current browser page. Use @ref IDs from the "
            "snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
            "Supported actions: click, dblclick, fill, type, scroll, hover, press, "
            "Interact with the current browser page using @ref IDs from the snapshot. "
            "Actions: click, dblclick, fill, type, scroll, hover, press, "
            "check, uncheck, select, wait, back, forward, reload. "
            "fill clears the field before typing; type appends without clearing. "
            "wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
            "Example login flow: fill @e1 with email → fill @e2 with password → "
            "click @e3 (submit) → browser_navigate to the target page."
            "fill clears field first; type appends. "
            "wait accepts CSS selector or milliseconds (e.g. '1000'). "
            "Returns updated snapshot."
        )

    @property
@@ -589,30 +580,21 @@ class BrowserActTool(BaseTool):
                        "forward",
                        "reload",
                    ],
                    "description": "The action to perform.",
                    "description": "Action to perform.",
                },
                "target": {
                    "type": "string",
                    "description": (
                        "Element to target. Use @ref from snapshot (e.g. '@e3'), "
                        "a CSS selector, or a text description. "
                        "Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
                        "For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
                    ),
                    "description": "@ref ID (e.g. '@e3'), CSS selector, or text. Required for: click, dblclick, fill, type, hover, check, uncheck, select. For wait: CSS selector or milliseconds string (e.g. '1000').",
                },
                "value": {
                    "type": "string",
                    "description": (
                        "For fill/type: the text to enter. "
                        "For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
                        "For select: the option value to select."
                    ),
                    "description": "Text for fill/type, key for press (e.g. 'Enter'), option for select.",
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down", "left", "right"],
                    "default": "down",
                    "description": "For scroll: direction to scroll.",
                    "description": "Scroll direction (default: down).",
                },
            },
            "required": ["action"],
@@ -759,12 +741,10 @@ class BrowserScreenshotTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Take a screenshot of the current browser page and save it to the workspace. "
            "IMPORTANT: After calling this tool, immediately call read_workspace_file "
            "with the returned file_id to display the image inline to the user — "
            "the screenshot is not visible until you do this. "
            "With annotate=true (default), @ref labels are overlaid on interactive "
            "elements, making it easy to see which @ref ID maps to which element on screen."
            "Screenshot the current browser page and save to workspace. "
            "annotate=true overlays @ref labels on elements. "
            "IMPORTANT: After calling, you MUST immediately call read_workspace_file with the "
            "returned file_id to display the image inline."
        )

    @property
@@ -775,12 +755,12 @@ class BrowserScreenshotTool(BaseTool):
            "annotate": {
                "type": "boolean",
                "default": True,
                "description": "Overlay @ref labels on interactive elements (default: true).",
                "description": "Overlay @ref labels (default: true).",
            },
            "filename": {
                "type": "string",
                "default": "screenshot.png",
                "description": "Filename to save in the workspace.",
                "description": "Workspace filename (default: screenshot.png).",
            },
        },
    }
@@ -108,22 +108,12 @@ class AgentOutputTool(BaseTool):

    @property
    def description(self) -> str:
        return """Retrieve execution outputs from agents in the user's library.

Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'

Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'

Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
  If the execution is running/queued, waits up to this many seconds for completion.
  Returns current status on timeout. If already finished, returns immediately.
"""
        return (
            "Retrieve execution outputs from a library agent. "
            "Identify by agent_name, library_agent_id, or store_slug. "
            "Filter by execution_id or run_time. "
            "Optionally wait for running executions."
        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -132,32 +122,29 @@ class AgentOutputTool(BaseTool):
            "properties": {
                "agent_name": {
                    "type": "string",
                    "description": "Agent name to search for in user's library (fuzzy match)",
                    "description": "Agent name (fuzzy match).",
                },
                "library_agent_id": {
                    "type": "string",
                    "description": "Exact library agent ID",
                    "description": "Library agent ID.",
                },
                "store_slug": {
                    "type": "string",
                    "description": "Marketplace identifier: 'username/agent-slug'",
                    "description": "Marketplace 'username/agent-name'.",
                },
                "execution_id": {
                    "type": "string",
                    "description": "Specific execution ID to retrieve",
                    "description": "Specific execution ID.",
                },
                "run_time": {
                    "type": "string",
                    "description": (
                        "Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
                    ),
                    "description": "Time filter: 'latest', 'today', 'yesterday', 'last week', 'last 7 days', 'last month', 'last 30 days', 'YYYY-MM-DD', or ISO datetime.",
                },
                "wait_if_running": {
                    "type": "integer",
                    "description": (
                        "Max seconds to wait if execution is still running (0-300). "
                        "If running, waits for completion. Returns current state on timeout."
                    ),
                    "description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
                    "minimum": 0,
                    "maximum": 300,
                },
            },
            "required": [],
@@ -42,15 +42,9 @@ class BashExecTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Execute a Bash command or script. "
            "Full Bash scripting is supported (loops, conditionals, pipes, "
            "functions, etc.). "
            "The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
            "tools — files created by either are immediately visible to both. "
            "Execution is killed after the timeout (default 30s, max 120s). "
            "Returns stdout and stderr. "
            "Useful for file manipulation, data processing, running scripts, "
            "and installing packages."
            "Execute a Bash command or script. Shares filesystem with SDK file tools. "
            "Useful for scripts, data processing, and package installation. "
            "Killed after timeout (default 30s, max 120s)."
        )

    @property
@@ -60,13 +54,11 @@ class BashExecTool(BaseTool):
            "properties": {
                "command": {
                    "type": "string",
                    "description": "Bash command or script to execute.",
                    "description": "Bash command or script.",
                },
                "timeout": {
                    "type": "integer",
                    "description": (
                        "Max execution time in seconds (default 30, max 120)."
                    ),
                    "description": "Max seconds (default 30, max 120).",
                    "default": 30,
                },
            },
@@ -30,12 +30,7 @@ class ContinueRunBlockTool(BaseTool):

    @property
    def description(self) -> str:
        return (
            "Continue executing a block after human review approval. "
            "Use this after a run_block call returned review_required. "
            "Pass the review_id from the review_required response. "
            "The block will execute with the original pre-approved input data."
        )
        return "Resume block execution after a run_block call returned review_required. Pass the review_id."

    @property
    def parameters(self) -> dict[str, Any]:
@@ -44,10 +39,7 @@ class ContinueRunBlockTool(BaseTool):
            "properties": {
                "review_id": {
                    "type": "string",
                    "description": (
                        "The review_id from a previous review_required response. "
                        "This resumes execution with the pre-approved input data."
                    ),
                    "description": "review_id from the review_required response.",
                },
            },
            "required": ["review_id"],
@@ -23,12 +23,8 @@ class CreateAgentTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Create a new agent workflow. Pass `agent_json` with the complete "
            "agent graph JSON you generated using block schemas from find_block. "
            "The tool validates, auto-fixes, and saves.\n\n"
            "IMPORTANT: Before calling this tool, search for relevant existing agents "
            "using find_library_agent that could be used as building blocks. "
            "Pass their IDs in the library_agent_ids parameter."
            "Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
            "Before calling, search for existing agents with find_library_agent."
        )

    @property
@@ -42,34 +38,21 @@ class CreateAgentTool(BaseTool):
            "properties": {
                "agent_json": {
                    "type": "object",
                    "description": (
                        "The agent JSON to validate and save. "
                        "Must contain 'nodes' and 'links' arrays, and optionally "
                        "'name' and 'description'."
                    ),
                    "description": "Agent graph with 'nodes' and 'links' arrays.",
                },
                "library_agent_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "List of library agent IDs to use as building blocks."
                    ),
                    "description": "Library agent IDs as building blocks.",
                },
                "save": {
                    "type": "boolean",
                    "description": (
                        "Whether to save the agent. Default is true. "
                        "Set to false for preview only."
                    ),
                    "description": "Save the agent (default: true). False for preview.",
                    "default": True,
                },
                "folder_id": {
                    "type": "string",
                    "description": (
                        "Optional folder ID to save the agent into. "
                        "If not provided, the agent is saved at root level. "
                        "Use list_folders to find available folders."
                    ),
                    "description": "Folder ID to save into (default: root).",
                },
            },
            "required": ["agent_json"],
@@ -23,9 +23,7 @@ class CustomizeAgentTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Customize a marketplace or template agent. Pass `agent_json` "
            "with the complete customized agent JSON. The tool validates, "
            "auto-fixes, and saves."
            "Customize a marketplace/template agent. Validates, auto-fixes, and saves."
        )

    @property
@@ -39,32 +37,21 @@ class CustomizeAgentTool(BaseTool):
            "properties": {
                "agent_json": {
                    "type": "object",
                    "description": (
                        "Complete customized agent JSON to validate and save. "
                        "Optionally include 'name' and 'description'."
                    ),
                    "description": "Customized agent JSON with nodes and links.",
                },
                "library_agent_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "List of library agent IDs to use as building blocks."
                    ),
                    "description": "Library agent IDs as building blocks.",
                },
                "save": {
                    "type": "boolean",
                    "description": (
                        "Whether to save the customized agent. Default is true."
                    ),
                    "description": "Save the agent (default: true). False for preview.",
                    "default": True,
                },
                "folder_id": {
                    "type": "string",
                    "description": (
                        "Optional folder ID to save the agent into. "
                        "If not provided, the agent is saved at root level. "
                        "Use list_folders to find available folders."
                    ),
                    "description": "Folder ID to save into (default: root).",
                },
            },
            "required": ["agent_json"],
@@ -23,12 +23,8 @@ class EditAgentTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Edit an existing agent. Pass `agent_json` with the complete "
            "updated agent JSON you generated. The tool validates, auto-fixes, "
            "and saves.\n\n"
            "IMPORTANT: Before calling this tool, if the changes involve adding new "
            "functionality, search for relevant existing agents using find_library_agent "
            "that could be used as building blocks."
            "Edit an existing agent. Validates, auto-fixes, and saves. "
            "Before calling, search for existing agents with find_library_agent."
        )

    @property
@@ -42,33 +38,20 @@ class EditAgentTool(BaseTool):
            "properties": {
                "agent_id": {
                    "type": "string",
                    "description": (
                        "The ID of the agent to edit. "
                        "Can be a graph ID or library agent ID."
                    ),
                    "description": "Graph ID or library agent ID to edit.",
                },
                "agent_json": {
                    "type": "object",
                    "description": (
                        "Complete updated agent JSON to validate and save. "
                        "Must contain 'nodes' and 'links'. "
                        "Include 'name' and/or 'description' if they need "
                        "to be updated."
                    ),
                    "description": "Updated agent JSON with nodes and links.",
                },
                "library_agent_ids": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "List of library agent IDs to use as building blocks for the changes."
                    ),
                    "description": "Library agent IDs as building blocks.",
                },
                "save": {
                    "type": "boolean",
                    "description": (
                        "Whether to save the changes. "
                        "Default is true. Set to false for preview only."
                    ),
                    "description": "Save changes (default: true). False for preview.",
                    "default": True,
                },
            },
@@ -134,11 +134,7 @@ class SearchFeatureRequestsTool(BaseTool):

@property
def description(self) -> str:
return (
"Search existing feature requests to check if a similar request "
"already exists before creating a new one. Returns matching feature "
"requests with their ID, title, and description."
)
return "Search existing feature requests. Check before creating a new one."

@property
def parameters(self) -> dict[str, Any]:
@@ -234,14 +230,9 @@ class CreateFeatureRequestTool(BaseTool):
@property
def description(self) -> str:
return (
"Create a new feature request or add a customer need to an existing one. "
"Always search first with search_feature_requests to avoid duplicates. "
"If a matching request exists, pass its ID as existing_issue_id to add "
"the user's need to it instead of creating a duplicate. "
"IMPORTANT: Never include personally identifiable information (PII) in "
"the title or description — no names, emails, phone numbers, company "
"names, or other identifying details. Write titles and descriptions in "
"generic, feature-focused language."
"Create a feature request or add need to existing one. "
"Search first to avoid duplicates. Pass existing_issue_id to add to existing. "
"Never include PII (names, emails, phone numbers, company names) in title/description."
)

@property
@@ -251,28 +242,15 @@ class CreateFeatureRequestTool(BaseTool):
"properties": {
"title": {
"type": "string",
"description": (
"Title for the feature request. Must be generic and "
"feature-focused — do not include any user names, emails, "
"company names, or other PII."
),
"description": "Feature request title. No names, emails, or company info.",
},
"description": {
"type": "string",
"description": (
"Detailed description of what the user wants and why. "
"Must not contain any personally identifiable information "
"(PII) — describe the feature need generically without "
"referencing specific users, companies, or contact details."
),
"description": "What the user wants and why. No names, emails, or company info.",
},
"existing_issue_id": {
"type": "string",
"description": (
"If adding a need to an existing feature request, "
"provide its Linear issue ID (from search results). "
"Omit to create a new feature request."
),
"description": "Linear issue ID to add need to (from search results).",
},
},
"required": ["title", "description"],

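The no-PII rule above lives only in the description text the model reads; nothing in the schema enforces it. As an illustration of the intent, here is a minimal caller-side redaction sketch — the function name and regex patterns are illustrative assumptions, not part of this codebase, and real PII scrubbing needs far more than two regexes:

```python
import re

# Crude patterns for the most common PII the tool description warns about.
# Illustrative only — these will miss many real-world PII forms.
_EMAIL = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
_PHONE = re.compile(r"\+?\d[\d\s().-]{7,}\d")


def redact_pii(text: str) -> str:
    """Replace email addresses and phone-like numbers with placeholders."""
    text = _EMAIL.sub("[email]", text)
    text = _PHONE.sub("[phone]", text)
    return text
```

Text with no matches passes through unchanged, so the sketch is safe to run on every title/description before submission.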
@@ -18,10 +18,7 @@ class FindAgentTool(BaseTool):

@property
def description(self) -> str:
return (
"Discover agents from the marketplace based on capabilities and "
"user needs, or look up a specific agent by its creator/slug ID."
)
return "Search marketplace agents by capability, or look up by slug ('username/agent-name')."

@property
def parameters(self) -> dict[str, Any]:
@@ -30,7 +27,7 @@ class FindAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": "Search query describing what the user wants to accomplish, or a creator/slug ID (e.g. 'username/agent-name') for direct lookup. Use single keywords for best results.",
"description": "Search keywords, or 'username/agent-name' for direct slug lookup.",
},
},
"required": ["query"],

@@ -54,13 +54,9 @@ class FindBlockTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for available blocks by name or description, or look up a "
"specific block by its ID. "
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, name, and description. "
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
"Search blocks by name or description. Returns block IDs for run_block. "
"Always call this FIRST to get block IDs before using run_block. "
"Then call run_block with the block's id and empty input_data to see its detailed schema."
)

@property
@@ -70,19 +66,11 @@ class FindBlockTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find blocks by name or description, "
"or a block ID (UUID) for direct lookup. "
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
"description": "Search keywords (e.g. 'email', 'http', 'ai').",
},
"include_schemas": {
"type": "boolean",
"description": (
"If true, include full input_schema and output_schema "
"for each block. Use when generating agent JSON that "
"needs block schemas. Default is false."
),
"description": "Include full input/output schemas (for agent JSON generation).",
"default": False,
},
},

@@ -19,13 +19,8 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
"Search user's library agents. Returns graph_id, schemas for sub-agent composition. "
"Omit query to list all."
)

@property
@@ -35,10 +30,7 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
"description": "Search by name/description. Omit to list all.",
},
},
"required": [],

@@ -22,20 +22,10 @@ class FixAgentGraphTool(BaseTool):
@property
def description(self) -> str:
return (
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
"- Missing or invalid UUIDs on nodes and links\n"
"- StoreValueBlock prerequisites for ConditionBlock\n"
"- Double curly brace escaping in prompt templates\n"
"- AddToList/AddToDictionary prerequisite blocks\n"
"- CodeExecutionBlock output field naming\n"
"- Missing credentials configuration\n"
"- Node X coordinate spacing (800+ units apart)\n"
"- AI model default parameters\n"
"- Link static properties based on input schema\n"
"- Type mismatches (inserts conversion blocks)\n\n"
"Returns the fixed agent JSON plus a list of fixes applied. "
"After fixing, the agent is re-validated. If still invalid, "
"the remaining errors are included in the response."
"Auto-fix common agent JSON issues: missing/invalid UUIDs, StoreValueBlock prerequisites, "
"double curly brace escaping, AddToList/AddToDictionary prerequisites, credentials, "
"node spacing, AI model defaults, link static properties, and type mismatches. "
"Returns fixed JSON and list of fixes applied."
)

@property

@@ -42,12 +42,7 @@ class GetAgentBuildingGuideTool(BaseTool):

@property
def description(self) -> str:
return (
"Returns the complete guide for building agent JSON graphs, including "
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
"Call this before generating agent JSON to ensure correct structure."
)
return "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage). Call before generating agent JSON."

@property
def parameters(self) -> dict[str, Any]:

@@ -25,8 +25,7 @@ class GetDocPageTool(BaseTool):
@property
def description(self) -> str:
return (
"Get the full content of a documentation page by its path. "
"Use this after search_docs to read the complete content of a relevant page."
"Read full documentation page content by path (from search_docs results)."
)

@property
@@ -36,10 +35,7 @@ class GetDocPageTool(BaseTool):
"properties": {
"path": {
"type": "string",
"description": (
"The path to the documentation file, as returned by search_docs. "
"Example: 'platform/block-sdk-guide.md'"
),
"description": "Doc file path (e.g. 'platform/block-sdk-guide.md').",
},
},
"required": ["path"],

@@ -38,11 +38,7 @@ class GetMCPGuideTool(BaseTool):

@property
def description(self) -> str:
return (
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
"Call before using run_mcp_tool if you need a server URL or auth info."
)
return "Get MCP server URLs and auth guide. Call before run_mcp_tool if you need a server URL or auth info."

@property
def parameters(self) -> dict[str, Any]:

@@ -88,10 +88,7 @@ class CreateFolderTool(BaseTool):

@property
def description(self) -> str:
return (
"Create a new folder in the user's library to organize agents. "
"Optionally nest it inside an existing folder using parent_id."
)
return "Create a library folder. Use parent_id to nest inside another folder."

@property
def requires_auth(self) -> bool:
@@ -104,22 +101,19 @@ class CreateFolderTool(BaseTool):
"properties": {
"name": {
"type": "string",
"description": "Name for the new folder (max 100 chars).",
"description": "Folder name (max 100 chars).",
},
"parent_id": {
"type": "string",
"description": (
"ID of the parent folder to nest inside. "
"Omit to create at root level."
),
"description": "Parent folder ID (omit for root).",
},
"icon": {
"type": "string",
"description": "Optional icon identifier for the folder.",
"description": "Icon identifier.",
},
"color": {
"type": "string",
"description": "Optional hex color code (#RRGGBB).",
"description": "Hex color (#RRGGBB).",
},
},
"required": ["name"],
@@ -175,13 +169,9 @@ class ListFoldersTool(BaseTool):
@property
def description(self) -> str:
return (
"List the user's library folders. "
"Omit parent_id to get the full folder tree. "
"Provide parent_id to list only direct children of that folder. "
"Set include_agents=true to also return the agents inside each folder "
"and root-level agents not in any folder. Always set include_agents=true "
"when the user asks about agents, wants to see what's in their folders, "
"or mentions agents alongside folders."
"List library folders. Omit parent_id for full tree. "
"Set include_agents=true when user asks about agents, wants to see "
"what's in their folders, or mentions agents alongside folders."
)

@property
@@ -195,17 +185,11 @@ class ListFoldersTool(BaseTool):
"properties": {
"parent_id": {
"type": "string",
"description": (
"List children of this folder. "
"Omit to get the full folder tree."
),
"description": "List children of this folder (omit for full tree).",
},
"include_agents": {
"type": "boolean",
"description": (
"Whether to include the list of agents inside each folder. "
"Defaults to false."
),
"description": "Include agents in each folder (default: false).",
},
},
"required": [],
@@ -357,10 +341,7 @@ class MoveFolderTool(BaseTool):

@property
def description(self) -> str:
return (
"Move a folder to a different parent folder. "
"Set target_parent_id to null to move to root level."
)
return "Move a folder. Set target_parent_id to null for root."

@property
def requires_auth(self) -> bool:
@@ -373,14 +354,11 @@ class MoveFolderTool(BaseTool):
"properties": {
"folder_id": {
"type": "string",
"description": "ID of the folder to move.",
"description": "Folder ID.",
},
"target_parent_id": {
"type": ["string", "null"],
"description": (
"ID of the new parent folder. "
"Use null to move to root level."
),
"description": "New parent folder ID (null for root).",
},
},
"required": ["folder_id"],
@@ -433,10 +411,7 @@ class DeleteFolderTool(BaseTool):

@property
def description(self) -> str:
return (
"Delete a folder from the user's library. "
"Agents inside the folder are moved to root level (not deleted)."
)
return "Delete a folder. Agents inside move to root (not deleted)."

@property
def requires_auth(self) -> bool:
@@ -499,10 +474,7 @@ class MoveAgentsToFolderTool(BaseTool):

@property
def description(self) -> str:
return (
"Move one or more agents to a folder. "
"Set folder_id to null to move agents to root level."
)
return "Move agents to a folder. Set folder_id to null for root."

@property
def requires_auth(self) -> bool:
@@ -516,13 +488,11 @@ class MoveAgentsToFolderTool(BaseTool):
"agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": "List of library agent IDs to move.",
"description": "Library agent IDs to move.",
},
"folder_id": {
"type": ["string", "null"],
"description": (
"Target folder ID. Use null to move to root level."
),
"description": "Target folder ID (null for root).",
},
},
"required": ["agent_ids"],

@@ -104,19 +104,11 @@ class RunAgentTool(BaseTool):

@property
def description(self) -> str:
return """Run or schedule an agent from the marketplace or user's library.

The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
- Returns missing credentials if user needs to configure them
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided

Identify the agent using either:
- username_agent_slug: Marketplace format 'username/agent-name'
- library_agent_id: ID of an agent in the user's library

For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
return (
"Run or schedule an agent. Automatically checks inputs and credentials. "
"Identify by username_agent_slug ('user/agent') or library_agent_id. "
"For scheduling, provide schedule_name + cron."
)

@property
def parameters(self) -> dict[str, Any]:
@@ -125,40 +117,38 @@ class RunAgentTool(BaseTool):
"properties": {
"username_agent_slug": {
"type": "string",
"description": "Agent identifier in format 'username/agent-name'",
"description": "Marketplace format 'username/agent-name'.",
},
"library_agent_id": {
"type": "string",
"description": "Library agent ID from user's library",
"description": "Library agent ID.",
},
"inputs": {
"type": "object",
"description": "Input values for the agent",
"description": "Input values for the agent.",
"additionalProperties": True,
},
"use_defaults": {
"type": "boolean",
"description": "Set to true to run with default values (user must confirm)",
"description": "Run with default values (confirm with user first).",
},
"schedule_name": {
"type": "string",
"description": "Name for scheduled execution (triggers scheduling mode)",
"description": "Name for scheduled execution. Providing this triggers scheduling mode (also requires cron).",
},
"cron": {
"type": "string",
"description": "Cron expression (5 fields: min hour day month weekday)",
"description": "Cron expression (min hour day month weekday).",
},
"timezone": {
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
"description": "IANA timezone (default: UTC).",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
"description": "Max seconds to wait for completion (0-300).",
"minimum": 0,
"maximum": 300,
},
},
"required": [],

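The `cron` parameter in the hunk above expects a classic 5-field expression (min hour day month weekday). A minimal shape check a caller could run before invoking the tool might look like this — the function name is illustrative and this is deliberately not a full cron parser (no range/bounds validation):

```python
def looks_like_cron(expr: str) -> bool:
    """Cheap sanity check for a 5-field cron string (min hour day month weekday).

    Accepts digits, '*', ',', '-', and '/' per field; rejects anything else,
    including macro forms like '@daily'. A real scheduler should still parse
    ranges and numeric bounds properly.
    """
    fields = expr.split()
    if len(fields) != 5:
        return False
    allowed = set("0123456789*,-/")
    return all(field and set(field) <= allowed for field in fields)
```

For example, `looks_like_cron("0 9 * * 1-5")` accepts a weekday-morning schedule, while a 4-field string is rejected before it ever reaches the API.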
@@ -45,13 +45,10 @@ class RunBlockTool(BaseTool):
@property
def description(self) -> str:
return (
"Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute. "
"If a block requires human review, use continue_run_block with the "
"review_id after the user approves."
"Execute a block. IMPORTANT: Always get block_id from find_block first "
"— do NOT guess or fabricate IDs. "
"Call with empty input_data to see schema, then with data to execute. "
"If review_required, use continue_run_block."
)

@property
@@ -61,28 +58,14 @@ class RunBlockTool(BaseTool):
"properties": {
"block_id": {
"type": "string",
"description": (
"The block's 'id' field from find_block results. "
"NEVER guess this - always get it from find_block first."
),
},
"block_name": {
"type": "string",
"description": (
"The block's human-readable name from find_block results. "
"Used for display purposes in the UI."
),
"description": "Block ID from find_block results.",
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. "
"First call with empty {} to see the block's schema, "
"then call again with proper values to execute."
),
"description": "Input values. Use {} first to see schema.",
},
},
"required": ["block_id", "block_name", "input_data"],
"required": ["block_id", "input_data"],
}

@property

@@ -57,10 +57,9 @@ class RunMCPToolTool(BaseTool):
@property
def description(self) -> str:
return (
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
"Two-step: (1) call with server_url to list available tools, "
"(2) call again with server_url + tool_name + tool_arguments to execute. "
"Call get_mcp_guide for known server URLs and auth details."
"Discover and execute MCP server tools. "
"Call with server_url only to list tools, then with tool_name + tool_arguments to execute. "
"Call get_mcp_guide first for server URLs and auth."
)

@property
@@ -70,24 +69,15 @@ class RunMCPToolTool(BaseTool):
"properties": {
"server_url": {
"type": "string",
"description": (
"URL of the MCP server (Streamable HTTP endpoint), "
"e.g. https://mcp.example.com/mcp"
),
"description": "MCP server URL (Streamable HTTP endpoint).",
},
"tool_name": {
"type": "string",
"description": (
"Name of the MCP tool to execute. "
"Omit on first call to discover available tools."
),
"description": "Tool to execute. Omit to discover available tools.",
},
"tool_arguments": {
"type": "object",
"description": (
"Arguments to pass to the selected tool. "
"Must match the tool's input schema returned during discovery."
),
"description": "Arguments matching the tool's input schema.",
},
},
"required": ["server_url"],

@@ -38,11 +38,7 @@ class SearchDocsTool(BaseTool):

@property
def description(self) -> str:
return (
"Search the AutoGPT platform documentation for information about "
"how to use the platform, build agents, configure blocks, and more. "
"Returns relevant documentation sections. Use get_doc_page to read full content."
)
return "Search platform documentation by keyword. Use get_doc_page to read full results."

@property
def parameters(self) -> dict[str, Any]:
@@ -51,10 +47,7 @@ class SearchDocsTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find relevant documentation. "
"Use natural language to describe what you're looking for."
),
"description": "Documentation search query.",
},
},
"required": ["query"],

@@ -0,0 +1,119 @@
"""Schema regression tests for all registered CoPilot tools.

Validates that every tool in TOOL_REGISTRY produces a well-formed schema:
- description is non-empty
- all `required` fields exist in `properties`
- every property has a `type` and `description`
- total schema character budget does not regress past threshold
"""

import json
from typing import Any, cast

import pytest

from backend.copilot.tools import TOOL_REGISTRY

# Character budget (~4 chars/token heuristic, targeting ~8000 tokens)
_CHAR_BUDGET = 32_000


@pytest.fixture(scope="module")
def all_tool_schemas() -> list[tuple[str, Any]]:
    """Return (tool_name, openai_schema) pairs for every registered tool."""
    return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]


def _get_parametrize_data() -> list[tuple[str, object]]:
    """Build parametrize data at collection time."""
    return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]


@pytest.mark.parametrize(
    "tool_name,schema",
    _get_parametrize_data(),
    ids=[name for name, _ in _get_parametrize_data()],
)
class TestToolSchema:
    """Validate schema invariants for every registered tool."""

    def test_description_non_empty(self, tool_name: str, schema: dict) -> None:
        desc = schema["function"].get("description", "")
        assert desc, f"Tool '{tool_name}' has an empty description"

    def test_required_fields_exist_in_properties(
        self, tool_name: str, schema: dict
    ) -> None:
        params = schema["function"].get("parameters", {})
        properties = params.get("properties", {})
        required = params.get("required", [])
        for field in required:
            assert field in properties, (
                f"Tool '{tool_name}': required field '{field}' "
                f"not found in properties {list(properties.keys())}"
            )

    def test_every_property_has_type_and_description(
        self, tool_name: str, schema: dict
    ) -> None:
        params = schema["function"].get("parameters", {})
        properties = params.get("properties", {})
        for prop_name, prop_def in properties.items():
            assert (
                "type" in prop_def
            ), f"Tool '{tool_name}', property '{prop_name}' is missing 'type'"
            assert (
                "description" in prop_def
            ), f"Tool '{tool_name}', property '{prop_name}' is missing 'description'"


def test_browser_act_action_enum_complete() -> None:
    """Assert browser_act action enum still contains all 14 supported actions.

    This prevents future PRs from accidentally dropping actions during description
    trimming. The enum is the authoritative list — this locks it at 14 values.
    """
    tool = TOOL_REGISTRY["browser_act"]
    schema = tool.as_openai_tool()
    fn_def = schema["function"]
    params = cast(dict[str, Any], fn_def.get("parameters", {}))
    actions = params["properties"]["action"]["enum"]
    expected = {
        "click",
        "dblclick",
        "fill",
        "type",
        "scroll",
        "hover",
        "press",
        "check",
        "uncheck",
        "select",
        "wait",
        "back",
        "forward",
        "reload",
    }
    assert set(actions) == expected, (
        f"browser_act action enum changed. Got {set(actions)}, expected {expected}. "
        "If you added/removed an action, update this test intentionally."
    )


def test_total_schema_char_budget() -> None:
    """Assert total tool schema size stays under the character budget.

    This locks in the 34% token reduction from #12398 and prevents future
    description bloat from eroding the gains. Uses character count with a
    ~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens).
    Character count is tokenizer-agnostic — no dependency on GPT or Claude
    tokenizers — while still providing a stable regression gate.
    """
    schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
    serialized = json.dumps(schemas)
    total_chars = len(serialized)
    assert total_chars < _CHAR_BUDGET, (
        f"Tool schemas use {total_chars} chars (~{total_chars // 4} tokens), "
        f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). "
        f"Description bloat detected — trim descriptions or raise the budget intentionally."
    )
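When the budget test above fails, the natural next step is to find which tool schemas are the heaviest. A sketch of that breakdown using the same serialized-char heuristic — shown here against toy stand-in schemas, since the function only needs the serialized dicts (the helper name is illustrative, not part of the test suite):

```python
import json


def schema_sizes(schemas: dict[str, dict]) -> list[tuple[str, int]]:
    """Return (tool_name, serialized_char_count) pairs, largest first."""
    sizes = [(name, len(json.dumps(schema))) for name, schema in schemas.items()]
    return sorted(sizes, key=lambda pair: pair[1], reverse=True)


# Toy stand-ins for TOOL_REGISTRY entries:
toy = {
    "web_fetch": {"function": {"description": "Fetch a public web page."}},
    "run_block": {"function": {"description": "Execute a block by id with input_data."}},
}
for name, chars in schema_sizes(toy):
    print(f"{name}: {chars} chars (~{chars // 4} tokens)")
```

Against the real registry, the same call would be `schema_sizes({name: tool.as_openai_tool() for name, tool in TOOL_REGISTRY.items()})`, making the biggest offenders obvious at a glance.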
@@ -22,17 +22,9 @@ class ValidateAgentGraphTool(BaseTool):
@property
def description(self) -> str:
return (
"Validate an agent JSON graph for correctness. Checks:\n"
"- All block_ids reference real blocks\n"
"- All links reference valid source/sink nodes and fields\n"
"- Required input fields are wired or have defaults\n"
"- Data types are compatible across links\n"
"- Nested sink links use correct notation\n"
"- Prompt templates use proper curly brace escaping\n"
"- AgentExecutorBlock configurations are valid\n\n"
"Call this after generating agent JSON to verify correctness. "
"If validation fails, either fix issues manually based on the error "
"descriptions, or call fix_agent_graph to auto-fix common problems."
"Validate agent JSON for correctness: block_ids, links, required fields, "
"type compatibility, nested sink notation, prompt brace escaping, "
"and AgentExecutorBlock configs. On failure, use fix_agent_graph to auto-fix."
)

@property
@@ -46,11 +38,7 @@ class ValidateAgentGraphTool(BaseTool):
"properties": {
"agent_json": {
"type": "object",
"description": (
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
"Each node needs: id (UUID), block_id, input_default, metadata. "
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
),
"description": "Agent JSON with 'nodes' and 'links' arrays.",
},
},
"required": ["agent_json"],

@@ -59,13 +59,7 @@ class WebFetchTool(BaseTool):

@property
def description(self) -> str:
return (
"Fetch the content of a public web page by URL. "
"Returns readable text extracted from HTML by default. "
"Useful for reading documentation, articles, and API responses. "
"Only supports HTTP/HTTPS GET requests to public URLs "
"(private/internal network addresses are blocked)."
)
return "Fetch a public web page. Public URLs only — internal addresses blocked. Returns readable text from HTML by default."

@property
def parameters(self) -> dict[str, Any]:
@@ -74,14 +68,11 @@ class WebFetchTool(BaseTool):
"properties": {
"url": {
"type": "string",
"description": "The public HTTP/HTTPS URL to fetch.",
"description": "Public HTTP/HTTPS URL.",
},
"extract_text": {
"type": "boolean",
"description": (
"If true (default), extract readable text from HTML. "
"If false, return raw content."
),
"description": "Extract text from HTML (default: true).",
"default": True,
},
},

@@ -27,6 +27,8 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_FILE_SIZE_MB = Config().max_file_size_mb
|
||||
|
||||
# Sentinel file_id used when a tool-result file is read directly from the local
|
||||
# host filesystem (rather than from workspace storage).
|
||||
_LOCAL_TOOL_RESULT_FILE_ID = "local"
|
||||
@@ -415,13 +417,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
return "List persistent workspace files. For ephemeral session files, use SDK Glob/Read instead. Optionally filter by path prefix."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -430,24 +426,17 @@ class ListWorkspaceFilesTool(BaseTool):
             "properties": {
                 "path_prefix": {
                     "type": "string",
-                    "description": (
-                        "Optional path prefix to filter files "
-                        "(e.g., '/documents/' to list only files in documents folder). "
-                        "By default, only files from the current session are listed."
-                    ),
+                    "description": "Filter by path prefix (e.g. '/documents/').",
                 },
                 "limit": {
                     "type": "integer",
-                    "description": "Maximum number of files to return (default 50, max 100)",
+                    "description": "Max files to return (default 50, max 100).",
                     "minimum": 1,
                     "maximum": 100,
                 },
                 "include_all_sessions": {
                     "type": "boolean",
-                    "description": (
-                        "If true, list files from all sessions. "
-                        "Default is false (only current session's files)."
-                    ),
+                    "description": "Include files from all sessions (default: false).",
                 },
             },
             "required": [],
@@ -530,18 +519,11 @@ class ReadWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Read a file from the user's persistent workspace (cloud storage). "
-            "These files survive across sessions. "
-            "For ephemeral session files, use the SDK Read tool instead. "
-            "Specify either file_id or path to identify the file. "
-            "For small text files, returns content directly. "
-            "For large or binary files, returns metadata and a download URL. "
-            "Use 'save_to_path' to copy the file to the working directory "
-            "(sandbox or ephemeral) for processing with bash_exec or file tools. "
-            "Use 'offset' and 'length' for paginated reads of large files "
-            "(e.g., persisted tool outputs). "
-            "Paths are scoped to the current session by default. "
-            "Use /sessions/<session_id>/... for cross-session access."
+            "Read a file from persistent workspace. Specify file_id or path. "
+            "Small text/image files return inline; large/binary return metadata+URL. "
+            "Use save_to_path to copy to working dir for processing. "
+            "Use offset/length for paginated reads. "
+            "Paths scoped to current session; use /sessions/<id>/... for cross-session access."
        )
 
     @property
@@ -551,48 +533,30 @@ class ReadWorkspaceFileTool(BaseTool):
             "properties": {
                 "file_id": {
                     "type": "string",
-                    "description": "The file's unique ID (from list_workspace_files)",
+                    "description": "File ID from list_workspace_files.",
                 },
                 "path": {
                     "type": "string",
-                    "description": (
-                        "The virtual file path (e.g., '/documents/report.pdf'). "
-                        "Scoped to current session by default."
-                    ),
+                    "description": "Virtual file path (e.g. '/documents/report.pdf').",
                 },
                 "save_to_path": {
                     "type": "string",
-                    "description": (
-                        "If provided, save the file to this path in the working "
-                        "directory (cloud sandbox when E2B is active, or "
-                        "ephemeral dir otherwise) so it can be processed with "
-                        "bash_exec or file tools. "
-                        "The file content is still returned in the response."
-                    ),
+                    "description": "Copy file to this working directory path for processing.",
                 },
                 "force_download_url": {
                     "type": "boolean",
-                    "description": (
-                        "If true, always return metadata+URL instead of inline content. "
-                        "Default is false (auto-selects based on file size/type)."
-                    ),
+                    "description": "Always return metadata+URL instead of inline content.",
                 },
                 "offset": {
                     "type": "integer",
-                    "description": (
-                        "Character offset to start reading from (0-based). "
-                        "Use with 'length' for paginated reads of large files."
-                    ),
+                    "description": "Character offset for paginated reads (0-based).",
                 },
                 "length": {
                     "type": "integer",
-                    "description": (
-                        "Maximum number of characters to return. "
-                        "Defaults to full file. Use with 'offset' for paginated reads."
-                    ),
+                    "description": "Max characters to return for paginated reads.",
                 },
             },
-            "required": [],  # At least one must be provided
+            "required": [],  # At least one of file_id or path must be provided
         }
 
     @property
@@ -755,15 +719,10 @@ class WriteWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Write or create a file in the user's persistent workspace (cloud storage). "
-            "These files survive across sessions. "
-            "For ephemeral session files, use the SDK Write tool instead. "
-            "Provide content as plain text via 'content', OR base64-encoded via "
-            "'content_base64', OR copy a file from the ephemeral working directory "
-            "via 'source_path'. Exactly one of these three is required. "
-            f"Maximum file size is {Config().max_file_size_mb}MB. "
-            "Files are saved to the current session's folder by default. "
-            "Use /sessions/<session_id>/... for cross-session access."
+            "Write a file to persistent workspace (survives across sessions). "
+            "Provide exactly one of: content (text), content_base64 (binary), "
+            f"or source_path (copy from working dir). Max {_MAX_FILE_SIZE_MB}MB. "
+            "Paths scoped to current session; use /sessions/<id>/... for cross-session access."
        )
 
     @property
@@ -773,51 +732,31 @@ class WriteWorkspaceFileTool(BaseTool):
             "properties": {
                 "filename": {
                     "type": "string",
-                    "description": "Name for the file (e.g., 'report.pdf')",
+                    "description": "Filename (e.g. 'report.pdf').",
                 },
                 "content": {
                     "type": "string",
-                    "description": (
-                        "Plain text content to write. Use this for text files "
-                        "(code, configs, documents, etc.). "
-                        "Mutually exclusive with content_base64 and source_path."
-                    ),
+                    "description": "Plain text content. Mutually exclusive with content_base64/source_path.",
                 },
                 "content_base64": {
                     "type": "string",
-                    "description": (
-                        "Base64-encoded file content. Use this for binary files "
-                        "(images, PDFs, etc.). "
-                        "Mutually exclusive with content and source_path."
-                    ),
+                    "description": "Base64-encoded binary content. Mutually exclusive with content/source_path.",
                 },
                 "source_path": {
                     "type": "string",
-                    "description": (
-                        "Path to a file in the ephemeral working directory to "
-                        "copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
-                        "Use this to persist files created by bash_exec or SDK Write. "
-                        "Mutually exclusive with content and content_base64."
-                    ),
+                    "description": "Working directory path to copy to workspace. Mutually exclusive with content/content_base64.",
                 },
                 "path": {
                     "type": "string",
-                    "description": (
-                        "Optional virtual path where to save the file "
-                        "(e.g., '/documents/report.pdf'). "
-                        "Defaults to '/(unknown)'. Scoped to current session."
-                    ),
+                    "description": "Virtual path (e.g. '/documents/report.pdf'). Defaults to '/(unknown)'.",
                 },
                 "mime_type": {
                     "type": "string",
-                    "description": (
-                        "Optional MIME type of the file. "
-                        "Auto-detected from filename if not provided."
-                    ),
+                    "description": "MIME type. Auto-detected from filename if omitted.",
                 },
                 "overwrite": {
                     "type": "boolean",
-                    "description": "Whether to overwrite if file exists at path (default: false)",
+                    "description": "Overwrite if file exists (default: false).",
                 },
             },
             "required": ["filename"],
@@ -859,10 +798,10 @@ class WriteWorkspaceFileTool(BaseTool):
             return resolved
         content: bytes = resolved
 
-        max_size = Config().max_file_size_mb * 1024 * 1024
+        max_size = _MAX_FILE_SIZE_MB * 1024 * 1024
         if len(content) > max_size:
             return ErrorResponse(
-                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
+                message=f"File too large. Maximum size is {_MAX_FILE_SIZE_MB}MB",
                 session_id=session_id,
             )
 
@@ -944,12 +883,7 @@ class DeleteWorkspaceFileTool(BaseTool):
 
     @property
    def description(self) -> str:
-        return (
-            "Delete a file from the user's persistent workspace (cloud storage). "
-            "Specify either file_id or path to identify the file. "
-            "Paths are scoped to the current session by default. "
-            "Use /sessions/<session_id>/... for cross-session access."
-        )
+        return "Delete a file from persistent workspace. Specify file_id or path. Paths scoped to current session; use /sessions/<id>/... for cross-session access."
 
     @property
     def parameters(self) -> dict[str, Any]:
@@ -958,17 +892,14 @@ class DeleteWorkspaceFileTool(BaseTool):
             "properties": {
                 "file_id": {
                     "type": "string",
-                    "description": "The file's unique ID (from list_workspace_files)",
+                    "description": "File ID from list_workspace_files.",
                 },
                 "path": {
                     "type": "string",
-                    "description": (
-                        "The virtual file path (e.g., '/documents/report.pdf'). "
-                        "Scoped to current session by default."
-                    ),
+                    "description": "Virtual file path.",
                 },
             },
-            "required": [],  # At least one must be provided
+            "required": [],  # At least one of file_id or path must be provided
         }
 
     @property
@@ -38,7 +38,7 @@ from backend.util.request import parse_url
 from .block import BlockInput
 from .db import BaseDbModel
 from .db import prisma as db
-from .db import query_raw_with_schema, transaction
+from .db import execute_raw_with_schema, query_raw_with_schema, transaction
 from .dynamic_fields import is_tool_pin, sanitize_pin_name
 from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
 from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
@@ -1669,16 +1669,15 @@ async def migrate_llm_models(migrate_to: LlmModel):
 
     # Update each block
     for id, path in llm_model_fields.items():
-        query = f"""
-        UPDATE platform."AgentNode"
+        query = """
+        UPDATE {schema_prefix}"AgentNode"
         SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
         WHERE "agentBlockId" = $3
         AND "constantInput" ? ($4)::text
-        AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
-        """
+        AND "constantInput"->>($4)::text NOT IN """ + escaped_enum_values
 
-        await db.execute_raw(
-            query,  # type: ignore - is supposed to be LiteralString
+        await execute_raw_with_schema(
+            query,
             [path],
             migrate_to.value,
             id,
@@ -0,0 +1,40 @@
+"""LLM Registry - Dynamic model management system."""
+
+from backend.blocks.llm import ModelMetadata
+
+from .notifications import (
+    publish_registry_refresh_notification,
+    subscribe_to_registry_refresh,
+)
+from .registry import (
+    RegistryModel,
+    RegistryModelCost,
+    RegistryModelCreator,
+    clear_registry_cache,
+    get_all_model_slugs_for_validation,
+    get_all_models,
+    get_default_model_slug,
+    get_enabled_models,
+    get_model,
+    get_schema_options,
+    refresh_llm_registry,
+)
+
+__all__ = [
+    # Models
+    "ModelMetadata",
+    "RegistryModel",
+    "RegistryModelCost",
+    "RegistryModelCreator",
+    # Cache management
+    "clear_registry_cache",
+    "publish_registry_refresh_notification",
+    "subscribe_to_registry_refresh",
+    # Read functions
+    "refresh_llm_registry",
+    "get_model",
+    "get_all_models",
+    "get_enabled_models",
+    "get_schema_options",
+    "get_default_model_slug",
+    "get_all_model_slugs_for_validation",
+]
@@ -0,0 +1,84 @@
+"""Pub/sub notifications for LLM registry cross-process synchronisation."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Awaitable, Callable
+
+logger = logging.getLogger(__name__)
+
+REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
+
+
+async def publish_registry_refresh_notification() -> None:
+    """Publish a refresh signal so all other workers reload their in-process cache."""
+    from backend.data.redis_client import get_redis_async
+
+    try:
+        redis = await get_redis_async()
+        await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
+        logger.debug("Published LLM registry refresh notification")
+    except Exception as e:
+        logger.warning("Failed to publish registry refresh notification: %s", e)
+
+
+async def subscribe_to_registry_refresh(
+    on_refresh: Callable[[], Awaitable[None]],
+) -> None:
+    """Listen for registry refresh signals and call on_refresh each time one arrives.
+
+    Designed to run as a long-lived background asyncio.Task. Automatically
+    reconnects if the Redis connection drops.
+
+    Args:
+        on_refresh: Async callable invoked on each refresh signal.
+            Typically ``llm_registry.refresh_llm_registry``.
+    """
+    from backend.data.redis_client import HOST, PASSWORD, PORT
+    from redis.asyncio import Redis as AsyncRedis
+
+    while True:
+        try:
+            # Dedicated connection — pub/sub must not share a connection used
+            # for regular commands.
+            redis_sub = AsyncRedis(
+                host=HOST, port=PORT, password=PASSWORD, decode_responses=True
+            )
+            pubsub = redis_sub.pubsub()
+            await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
+            logger.info("Subscribed to LLM registry refresh channel")
+
+            while True:
+                try:
+                    message = await pubsub.get_message(
+                        ignore_subscribe_messages=True, timeout=1.0
+                    )
+                    if (
+                        message
+                        and message["type"] == "message"
+                        and message["channel"] == REGISTRY_REFRESH_CHANNEL
+                    ):
+                        logger.debug("LLM registry refresh signal received")
+                        try:
+                            await on_refresh()
+                        except Exception as e:
+                            logger.error(
+                                "Error in registry on_refresh callback: %s", e
+                            )
+                except asyncio.CancelledError:
+                    raise
+                except Exception as e:
+                    logger.warning(
+                        "Error processing registry refresh message: %s", e
+                    )
+                    await asyncio.sleep(1)
+
+        except asyncio.CancelledError:
+            logger.info("LLM registry subscription task cancelled")
+            break
+        except Exception as e:
+            logger.warning(
+                "LLM registry subscription error: %s. Retrying in 5s...", e
+            )
+            await asyncio.sleep(5)
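The subscriber loop above is easiest to see with Redis swapped out. In this sketch an `asyncio.Queue` stands in for `REGISTRY_REFRESH_CHANNEL`, so the publish-then-callback flow runs without a broker; the queue, `demo`, and the "stop" sentinel are illustrative only, not part of the PR:

```python
import asyncio


async def demo() -> int:
    """Model the notifications flow: each 'refresh' message on the channel
    triggers on_refresh, exactly as subscribe_to_registry_refresh does."""
    channel: asyncio.Queue[str] = asyncio.Queue()
    refreshes = 0

    async def on_refresh() -> None:
        nonlocal refreshes
        refreshes += 1  # stands in for refresh_llm_registry()

    async def subscriber() -> None:
        while True:
            msg = await channel.get()
            if msg == "stop":   # test-only sentinel; the real loop runs forever
                break
            if msg == "refresh":
                await on_refresh()

    task = asyncio.create_task(subscriber())
    await channel.put("refresh")  # publish_registry_refresh_notification()
    await channel.put("refresh")  # a second worker mutated the registry
    await channel.put("stop")
    await task
    return refreshes
```

Running `asyncio.run(demo())` processes both signals, which is the behaviour the real task provides across processes via the shared Redis channel.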
autogpt_platform/backend/backend/data/llm_registry/registry.py (new file, 254 lines)
@@ -0,0 +1,254 @@
+"""Core LLM registry implementation for managing models dynamically."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+
+import prisma.models
+from pydantic import BaseModel, ConfigDict
+
+from backend.blocks.llm import ModelMetadata
+from backend.util.cache import cached
+
+logger = logging.getLogger(__name__)
+
+
+class RegistryModelCost(BaseModel):
+    """Cost configuration for an LLM model."""
+
+    model_config = ConfigDict(frozen=True)
+
+    unit: str  # "RUN" or "TOKENS"
+    credit_cost: int
+    credential_provider: str
+    credential_id: str | None = None
+    credential_type: str | None = None
+    currency: str | None = None
+    metadata: dict[str, Any] = {}
+
+
+class RegistryModelCreator(BaseModel):
+    """Creator information for an LLM model."""
+
+    model_config = ConfigDict(frozen=True)
+
+    id: str
+    name: str
+    display_name: str
+    description: str | None = None
+    website_url: str | None = None
+    logo_url: str | None = None
+
+
+class RegistryModel(BaseModel):
+    """Represents a model in the LLM registry."""
+
+    model_config = ConfigDict(frozen=True)
+
+    slug: str
+    display_name: str
+    description: str | None = None
+    metadata: ModelMetadata
+    capabilities: dict[str, Any] = {}
+    extra_metadata: dict[str, Any] = {}
+    provider_display_name: str
+    is_enabled: bool
+    is_recommended: bool = False
+    costs: tuple[RegistryModelCost, ...] = ()
+    creator: RegistryModelCreator | None = None
+
+    # Typed capability fields from DB schema
+    supports_tools: bool = False
+    supports_json_output: bool = False
+    supports_reasoning: bool = False
+    supports_parallel_tool_calls: bool = False
+
+
+# L1 in-process cache — Redis is the shared L2 via @cached(shared_cache=True)
+_dynamic_models: dict[str, RegistryModel] = {}
+_schema_options: list[dict[str, str]] = []
+_lock = asyncio.Lock()
+
+
+def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:  # type: ignore[name-defined]
+    """Transform a raw Prisma LlmModel record into a RegistryModel instance."""
+    costs = tuple(
+        RegistryModelCost(
+            unit=str(cost.unit),
+            credit_cost=cost.creditCost,
+            credential_provider=cost.credentialProvider,
+            credential_id=cost.credentialId,
+            credential_type=cost.credentialType,
+            currency=cost.currency,
+            metadata=dict(cost.metadata or {}),
+        )
+        for cost in (record.Costs or [])
+    )
+
+    creator = None
+    if record.Creator:
+        creator = RegistryModelCreator(
+            id=record.Creator.id,
+            name=record.Creator.name,
+            display_name=record.Creator.displayName,
+            description=record.Creator.description,
+            website_url=record.Creator.websiteUrl,
+            logo_url=record.Creator.logoUrl,
+        )
+
+    capabilities = dict(record.capabilities or {})
+
+    if not record.Provider:
+        logger.warning(
+            "LlmModel %s has no Provider despite NOT NULL FK - "
+            "falling back to providerId %s",
+            record.slug,
+            record.providerId,
+        )
+    provider_name = record.Provider.name if record.Provider else record.providerId
+    provider_display = (
+        record.Provider.displayName if record.Provider else record.providerId
+    )
+    creator_name = record.Creator.displayName if record.Creator else "Unknown"
+
+    if record.priceTier not in (1, 2, 3):
+        logger.warning(
+            "LlmModel %s has out-of-range priceTier=%s, defaulting to 1",
+            record.slug,
+            record.priceTier,
+        )
+    price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
+
+    metadata = ModelMetadata(
+        provider=provider_name,
+        context_window=record.contextWindow,
+        max_output_tokens=(
+            record.maxOutputTokens
+            if record.maxOutputTokens is not None
+            else record.contextWindow
+        ),
+        display_name=record.displayName,
+        provider_name=provider_display,
+        creator_name=creator_name,
+        price_tier=price_tier,
+    )
+
+    return RegistryModel(
+        slug=record.slug,
+        display_name=record.displayName,
+        description=record.description,
+        metadata=metadata,
+        capabilities=capabilities,
+        extra_metadata=dict(record.metadata or {}),
+        provider_display_name=provider_display,
+        is_enabled=record.isEnabled,
+        is_recommended=record.isRecommended,
+        costs=costs,
+        creator=creator,
+        supports_tools=record.supportsTools,
+        supports_json_output=record.supportsJsonOutput,
+        supports_reasoning=record.supportsReasoning,
+        supports_parallel_tool_calls=record.supportsParallelToolCalls,
+    )
+
+
+@cached(maxsize=1, ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
+async def _fetch_registry_from_db() -> list[RegistryModel]:
+    """Fetch all LLM models from the database.
+
+    Results are cached in Redis (shared_cache=True) so subsequent calls within
+    the TTL window skip the DB entirely — both within this process and across
+    all other workers that share the same Redis instance.
+    """
+    records = await prisma.models.LlmModel.prisma().find_many(  # type: ignore[attr-defined]
+        include={"Provider": True, "Costs": True, "Creator": True}
+    )
+    logger.info("Fetched %d LLM models from database", len(records))
+    return [_record_to_registry_model(r) for r in records]
+
+
+def clear_registry_cache() -> None:
+    """Invalidate the shared Redis cache for the registry DB fetch.
+
+    Call this before refresh_llm_registry() after any admin DB mutation so the
+    next fetch hits the database rather than serving the now-stale cached data.
+    """
+    _fetch_registry_from_db.cache_clear()
+
+
+async def refresh_llm_registry() -> None:
+    """Refresh the in-process L1 cache from Redis/DB.
+
+    On the first call (or after clear_registry_cache()), fetches fresh data
+    from the database and stores it in Redis. Subsequent calls by other
+    workers hit the Redis cache instead of the DB.
+    """
+    async with _lock:
+        try:
+            models = await _fetch_registry_from_db()
+            new_models = {m.slug: m for m in models}
+
+            global _dynamic_models, _schema_options
+            _dynamic_models = new_models
+            _schema_options = _build_schema_options()
+
+            logger.info(
+                "LLM registry refreshed: %d models, %d schema options",
+                len(_dynamic_models),
+                len(_schema_options),
+            )
+        except Exception as e:
+            logger.error("Failed to refresh LLM registry: %s", e, exc_info=True)
+            raise
+
+
+def _build_schema_options() -> list[dict[str, str]]:
+    """Build schema options for model selection dropdown. Only includes enabled models."""
+    return [
+        {
+            "label": model.display_name,
+            "value": model.slug,
+            "group": model.metadata.provider,
+            "description": model.description or "",
+        }
+        for model in sorted(
+            _dynamic_models.values(), key=lambda m: m.display_name.lower()
+        )
+        if model.is_enabled
+    ]
+
+
+def get_model(slug: str) -> RegistryModel | None:
+    """Get a model by slug from the registry."""
+    return _dynamic_models.get(slug)
+
+
+def get_all_models() -> list[RegistryModel]:
+    """Get all models from the registry (including disabled)."""
+    return list(_dynamic_models.values())
+
+
+def get_enabled_models() -> list[RegistryModel]:
+    """Get only enabled models from the registry."""
+    return [model for model in _dynamic_models.values() if model.is_enabled]
+
+
+def get_schema_options() -> list[dict[str, str]]:
+    """Get schema options for model selection dropdown (enabled models only)."""
+    return list(_schema_options)
+
+
+def get_default_model_slug() -> str | None:
+    """Get the default model slug (first recommended, or first enabled)."""
+    models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
+    recommended = next(
+        (m.slug for m in models if m.is_recommended and m.is_enabled), None
+    )
+    return recommended or next((m.slug for m in models if m.is_enabled), None)
+
+
+def get_all_model_slugs_for_validation() -> list[str]:
+    """Get all model slugs for validation (enabled models only)."""
+    return [model.slug for model in _dynamic_models.values() if model.is_enabled]
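The two-level caching that `clear_registry_cache()` plus `refresh_llm_registry()` implement can be modelled with a toy synchronous class. `TwoLevelCache` is illustrative only; a plain attribute stands in for the shared Redis L2 and a dict for the per-process L1:

```python
class TwoLevelCache:
    """Toy model of the registry's cache layering: refresh() populates the
    shared L2 on a miss, then copies it into the per-process L1; clear_l2()
    mirrors clear_registry_cache() after an admin mutation."""

    def __init__(self, fetch):
        self._fetch = fetch  # stands in for the DB query
        self._l2 = None      # shared cache (Redis in the real code)
        self._l1 = {}        # in-process cache (_dynamic_models)

    def clear_l2(self):
        self._l2 = None

    def refresh(self):
        if self._l2 is None:          # L2 miss: hit the "database" once
            self._l2 = self._fetch()
        self._l1 = dict(self._l2)     # every worker refreshes its own L1
        return self._l1
```

This captures why the admin path calls clear-then-refresh: without the clear, other workers refreshing within the TTL would keep serving the stale L2 snapshot.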
@@ -0,0 +1,358 @@
+"""Unit tests for the LLM registry module."""
+
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+import pydantic
+
+from backend.data.llm_registry.registry import (
+    RegistryModel,
+    RegistryModelCost,
+    RegistryModelCreator,
+    _build_schema_options,
+    _record_to_registry_model,
+    get_default_model_slug,
+    get_schema_options,
+    refresh_llm_registry,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_mock_record(**overrides):
+    """Build a realistic mock Prisma LlmModel record."""
+    provider = Mock()
+    provider.name = "openai"
+    provider.displayName = "OpenAI"
+
+    record = Mock()
+    record.slug = "openai/gpt-4o"
+    record.displayName = "GPT-4o"
+    record.description = "Latest GPT model"
+    record.providerId = "provider-uuid"
+    record.Provider = provider
+    record.creatorId = "creator-uuid"
+    record.Creator = None
+    record.contextWindow = 128000
+    record.maxOutputTokens = 16384
+    record.priceTier = 2
+    record.isEnabled = True
+    record.isRecommended = False
+    record.supportsTools = True
+    record.supportsJsonOutput = True
+    record.supportsReasoning = False
+    record.supportsParallelToolCalls = True
+    record.capabilities = {}
+    record.metadata = {}
+    record.Costs = []
+
+    for key, value in overrides.items():
+        setattr(record, key, value)
+    return record
+
+
+def _make_registry_model(**kwargs) -> RegistryModel:
+    """Build a minimal RegistryModel for testing registry-level functions."""
+    from backend.blocks.llm import ModelMetadata
+
+    defaults = dict(
+        slug="openai/gpt-4o",
+        display_name="GPT-4o",
+        description=None,
+        metadata=ModelMetadata(
+            provider="openai",
+            context_window=128000,
+            max_output_tokens=16384,
+            display_name="GPT-4o",
+            provider_name="OpenAI",
+            creator_name="Unknown",
+            price_tier=2,
+        ),
+        capabilities={},
+        extra_metadata={},
+        provider_display_name="OpenAI",
+        is_enabled=True,
+        is_recommended=False,
+    )
+    defaults.update(kwargs)
+    return RegistryModel(**defaults)
+
+
+# ---------------------------------------------------------------------------
+# _record_to_registry_model tests
+# ---------------------------------------------------------------------------
+
+
+def test_record_to_registry_model():
+    """Happy-path: well-formed record produces a correct RegistryModel."""
+    record = _make_mock_record()
+    model = _record_to_registry_model(record)
+
+    assert model.slug == "openai/gpt-4o"
+    assert model.display_name == "GPT-4o"
+    assert model.description == "Latest GPT model"
+    assert model.provider_display_name == "OpenAI"
+    assert model.is_enabled is True
+    assert model.is_recommended is False
+    assert model.supports_tools is True
+    assert model.supports_json_output is True
+    assert model.supports_reasoning is False
+    assert model.supports_parallel_tool_calls is True
+    assert model.metadata.provider == "openai"
+    assert model.metadata.context_window == 128000
+    assert model.metadata.max_output_tokens == 16384
+    assert model.metadata.price_tier == 2
+    assert model.creator is None
+    assert model.costs == ()
+
+
+def test_record_to_registry_model_missing_provider(caplog):
+    """Record with no Provider relation falls back to providerId and logs a warning."""
+    record = _make_mock_record(Provider=None, providerId="provider-uuid")
+    with caplog.at_level("WARNING"):
+        model = _record_to_registry_model(record)
+
+    assert "no Provider" in caplog.text
+    assert model.metadata.provider == "provider-uuid"
+    assert model.provider_display_name == "provider-uuid"
+
+
+def test_record_to_registry_model_missing_creator():
+    """When Creator is None, creator_name defaults to 'Unknown' and creator field is None."""
+    record = _make_mock_record(Creator=None)
+    model = _record_to_registry_model(record)
+
+    assert model.creator is None
+    assert model.metadata.creator_name == "Unknown"
+
+
+def test_record_to_registry_model_with_creator():
+    """When Creator is present, it is parsed into RegistryModelCreator."""
+    creator_mock = Mock()
+    creator_mock.id = "creator-uuid"
+    creator_mock.name = "openai"
+    creator_mock.displayName = "OpenAI"
+    creator_mock.description = "AI company"
+    creator_mock.websiteUrl = "https://openai.com"
+    creator_mock.logoUrl = "https://openai.com/logo.png"
+
+    record = _make_mock_record(Creator=creator_mock)
+    model = _record_to_registry_model(record)
+
+    assert model.creator is not None
+    assert isinstance(model.creator, RegistryModelCreator)
+    assert model.creator.id == "creator-uuid"
+    assert model.creator.display_name == "OpenAI"
+    assert model.metadata.creator_name == "OpenAI"
+
+
+def test_record_to_registry_model_null_max_output_tokens():
+    """maxOutputTokens=None falls back to contextWindow."""
+    record = _make_mock_record(maxOutputTokens=None, contextWindow=64000)
+    model = _record_to_registry_model(record)
+
+    assert model.metadata.max_output_tokens == 64000
+
+
+def test_record_to_registry_model_invalid_price_tier(caplog):
+    """Out-of-range priceTier is coerced to 1 and a warning is logged."""
+    record = _make_mock_record(priceTier=99)
+    with caplog.at_level("WARNING"):
+        model = _record_to_registry_model(record)
+
+    assert "out-of-range priceTier" in caplog.text
+    assert model.metadata.price_tier == 1
+
+
+def test_record_to_registry_model_with_costs():
+    """Costs are parsed into RegistryModelCost tuples."""
+    cost_mock = Mock()
+    cost_mock.unit = "TOKENS"
+    cost_mock.creditCost = 10
+    cost_mock.credentialProvider = "openai"
+    cost_mock.credentialId = None
+    cost_mock.credentialType = None
+    cost_mock.currency = "USD"
+    cost_mock.metadata = {}
+
+    record = _make_mock_record(Costs=[cost_mock])
+    model = _record_to_registry_model(record)
+
+    assert len(model.costs) == 1
+    cost = model.costs[0]
+    assert isinstance(cost, RegistryModelCost)
+    assert cost.unit == "TOKENS"
+    assert cost.credit_cost == 10
+    assert cost.credential_provider == "openai"
+
+
+# ---------------------------------------------------------------------------
+# get_default_model_slug tests
+# ---------------------------------------------------------------------------
+
+
+def test_get_default_model_slug_recommended():
+    """Recommended model is preferred over non-recommended enabled models."""
+    import backend.data.llm_registry.registry as reg
+
+    reg._dynamic_models = {
+        "openai/gpt-4o": _make_registry_model(
+            slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
+        ),
+        "openai/gpt-4o-recommended": _make_registry_model(
+            slug="openai/gpt-4o-recommended",
+            display_name="GPT-4o Recommended",
+            is_recommended=True,
+        ),
+    }
+
+    result = get_default_model_slug()
+    assert result == "openai/gpt-4o-recommended"
+
+
+def test_get_default_model_slug_fallback():
+    """With no recommended model, falls back to first enabled (alphabetical)."""
+    import backend.data.llm_registry.registry as reg
+
+    reg._dynamic_models = {
+        "openai/gpt-4o": _make_registry_model(
+            slug="openai/gpt-4o", display_name="GPT-4o", is_recommended=False
+        ),
+        "openai/gpt-3.5": _make_registry_model(
+            slug="openai/gpt-3.5", display_name="GPT-3.5", is_recommended=False
+        ),
+    }
+
+    result = get_default_model_slug()
+    # Sorted alphabetically: GPT-3.5 < GPT-4o
+    assert result == "openai/gpt-3.5"
+
+
+def test_get_default_model_slug_empty():
+    """Empty registry returns None."""
+    import backend.data.llm_registry.registry as reg
+
+    reg._dynamic_models = {}
+
+    result = get_default_model_slug()
+    assert result is None
+
+
+# ---------------------------------------------------------------------------
+# _build_schema_options / get_schema_options tests
+# ---------------------------------------------------------------------------
+
+
+def test_build_schema_options():
+    """Only enabled models appear, sorted case-insensitively."""
+    import backend.data.llm_registry.registry as reg
+
+    reg._dynamic_models = {
+        "openai/gpt-4o": _make_registry_model(
+            slug="openai/gpt-4o", display_name="GPT-4o", is_enabled=True
+        ),
+        "openai/disabled": _make_registry_model(
+            slug="openai/disabled", display_name="Disabled Model", is_enabled=False
+        ),
+        "openai/gpt-3.5": _make_registry_model(
+            slug="openai/gpt-3.5", display_name="gpt-3.5", is_enabled=True
+        ),
+    }
+
+    options = _build_schema_options()
+    slugs = [o["value"] for o in options]
+
+    # disabled model should be excluded
+    assert "openai/disabled" not in slugs
+    # only enabled models
+    assert "openai/gpt-4o" in slugs
+    assert "openai/gpt-3.5" in slugs
+    # case-insensitive sort: "gpt-3.5" < "GPT-4o" (both lowercase: "gpt-3.5" < "gpt-4o")
+    assert slugs.index("openai/gpt-3.5") < slugs.index("openai/gpt-4o")
+
+    # Verify structure
+    for option in options:
+        assert "label" in option
+        assert "value" in option
|
||||
assert "group" in option
|
||||
assert "description" in option
|
||||
|
||||
|
||||
def test_get_schema_options_returns_copy():
|
||||
"""Mutating the returned list does not affect the internal cache."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
reg._dynamic_models = {
|
||||
"openai/gpt-4o": _make_registry_model(slug="openai/gpt-4o", display_name="GPT-4o"),
|
||||
}
|
||||
reg._schema_options = _build_schema_options()
|
||||
|
||||
options = get_schema_options()
|
||||
original_length = len(options)
|
||||
options.append({"label": "Injected", "value": "evil/model", "group": "evil", "description": ""})
|
||||
|
||||
# Internal state should be unchanged
|
||||
assert len(get_schema_options()) == original_length
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic frozen model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_model_frozen():
|
||||
"""Pydantic frozen=True should reject attribute assignment."""
|
||||
model = _make_registry_model()
|
||||
|
||||
with pytest.raises((pydantic.ValidationError, TypeError)):
|
||||
model.slug = "changed/slug" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_registry_model_cost_frozen():
|
||||
"""RegistryModelCost is also frozen."""
|
||||
cost = RegistryModelCost(
|
||||
unit="TOKENS",
|
||||
credit_cost=5,
|
||||
credential_provider="openai",
|
||||
)
|
||||
with pytest.raises((pydantic.ValidationError, TypeError)):
|
||||
cost.unit = "RUN" # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# refresh_llm_registry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_llm_registry():
|
||||
"""Mock prisma find_many, verify cache is populated after refresh."""
|
||||
import backend.data.llm_registry.registry as reg
|
||||
|
||||
record = _make_mock_record()
|
||||
mock_find_many = AsyncMock(return_value=[record])
|
||||
|
||||
with patch("prisma.models.LlmModel.prisma") as mock_prisma_cls:
|
||||
mock_prisma_instance = Mock()
|
||||
mock_prisma_instance.find_many = mock_find_many
|
||||
mock_prisma_cls.return_value = mock_prisma_instance
|
||||
|
||||
# Clear state first
|
||||
reg._dynamic_models = {}
|
||||
reg._schema_options = []
|
||||
|
||||
await refresh_llm_registry()
|
||||
|
||||
assert "openai/gpt-4o" in reg._dynamic_models
|
||||
model = reg._dynamic_models["openai/gpt-4o"]
|
||||
assert isinstance(model, RegistryModel)
|
||||
assert model.slug == "openai/gpt-4o"
|
||||
# Schema options should be populated too
|
||||
assert len(reg._schema_options) == 1
|
||||
assert reg._schema_options[0]["value"] == "openai/gpt-4o"
|
||||
@@ -224,7 +224,7 @@ async def execute_node(
     # Sanity check: validate the execution input.
     input_data, error = validate_exec(node, data.inputs, resolve_input=False)
     if input_data is None:
-        log_metadata.error(f"Skip execution, input validation error: {error}")
+        log_metadata.warning(f"Skip execution, input validation error: {error}")
         yield "error", error
         return

@@ -0,0 +1,6 @@
"""LLM registry API (public + admin)."""

from .admin_routes import router as admin_router
from .routes import router

__all__ = ["router", "admin_router"]
115  autogpt_platform/backend/backend/server/v2/llm/admin_model.py  Normal file
@@ -0,0 +1,115 @@
"""Request/response models for LLM registry admin API."""

from typing import Any

from pydantic import BaseModel, Field


class CreateLlmProviderRequest(BaseModel):
    """Request model for creating an LLM provider."""

    name: str = Field(
        ..., description="Provider identifier (e.g., 'openai', 'anthropic')"
    )
    display_name: str = Field(..., description="Human-readable provider name")
    description: str | None = Field(None, description="Provider description")
    default_credential_provider: str | None = Field(
        None, description="Default credential system identifier"
    )
    default_credential_id: str | None = Field(None, description="Default credential ID")
    default_credential_type: str | None = Field(
        None, description="Default credential type"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )


class UpdateLlmProviderRequest(BaseModel):
    """Request model for updating an LLM provider."""

    display_name: str | None = Field(None, description="Human-readable provider name")
    description: str | None = Field(None, description="Provider description")
    default_credential_provider: str | None = Field(
        None, description="Default credential system identifier"
    )
    default_credential_id: str | None = Field(None, description="Default credential ID")
    default_credential_type: str | None = Field(
        None, description="Default credential type"
    )
    metadata: dict[str, Any] | None = Field(None, description="Additional metadata")


class CreateLlmModelRequest(BaseModel):
    """Request model for creating an LLM model."""

    slug: str = Field(..., description="Model slug (e.g., 'gpt-4', 'claude-3-opus')")
    display_name: str = Field(..., description="Human-readable model name")
    description: str | None = Field(None, description="Model description")
    provider_id: str = Field(..., description="Provider ID (UUID)")
    creator_id: str | None = Field(None, description="Creator ID (UUID)")
    context_window: int = Field(
        ..., description="Maximum context window in tokens", gt=0
    )
    max_output_tokens: int | None = Field(
        None, description="Maximum output tokens (None if unlimited)", gt=0
    )
    price_tier: int = Field(
        ..., description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
    )
    is_enabled: bool = Field(default=True, description="Whether the model is enabled")
    is_recommended: bool = Field(
        default=False, description="Whether the model is recommended"
    )
    supports_tools: bool = Field(default=False, description="Supports function calling")
    supports_json_output: bool = Field(
        default=False, description="Supports JSON output mode"
    )
    supports_reasoning: bool = Field(
        default=False, description="Supports reasoning mode"
    )
    supports_parallel_tool_calls: bool = Field(
        default=False, description="Supports parallel tool calls"
    )
    capabilities: dict[str, Any] = Field(
        default_factory=dict, description="Additional capabilities"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )
    costs: list[dict[str, Any]] = Field(
        default_factory=list, description="Cost entries for the model"
    )


class UpdateLlmModelRequest(BaseModel):
    """Request model for updating an LLM model."""

    display_name: str | None = Field(None, description="Human-readable model name")
    description: str | None = Field(None, description="Model description")
    creator_id: str | None = Field(None, description="Creator ID (UUID)")
    context_window: int | None = Field(
        None, description="Maximum context window in tokens", gt=0
    )
    max_output_tokens: int | None = Field(
        None, description="Maximum output tokens (None if unlimited)", gt=0
    )
    price_tier: int | None = Field(
        None, description="Price tier (1=cheapest, 2=medium, 3=expensive)", ge=1, le=3
    )
    is_enabled: bool | None = Field(None, description="Whether the model is enabled")
    is_recommended: bool | None = Field(
        None, description="Whether the model is recommended"
    )
    supports_tools: bool | None = Field(None, description="Supports function calling")
    supports_json_output: bool | None = Field(
        None, description="Supports JSON output mode"
    )
    supports_reasoning: bool | None = Field(None, description="Supports reasoning mode")
    supports_parallel_tool_calls: bool | None = Field(
        None, description="Supports parallel tool calls"
    )
    capabilities: dict[str, Any] | None = Field(
        None, description="Additional capabilities"
    )
    metadata: dict[str, Any] | None = Field(None, description="Additional metadata")
689  autogpt_platform/backend/backend/server/v2/llm/admin_routes.py  Normal file
@@ -0,0 +1,689 @@
"""Admin API for LLM registry management.

Provides endpoints for:
- Reading creators (GET)
- Creating, updating, and deleting models
- Creating, updating, and deleting providers

All endpoints require admin authentication. Mutations refresh the registry cache.
"""

import logging
from typing import Any

import autogpt_libs.auth
import prisma
from fastapi import APIRouter, HTTPException, Security, status

from backend.server.v2.llm import db_write
from backend.server.v2.llm.admin_model import (
    CreateLlmModelRequest,
    CreateLlmProviderRequest,
    UpdateLlmModelRequest,
    UpdateLlmProviderRequest,
)

logger = logging.getLogger(__name__)

router = APIRouter()


def _map_provider_response(provider: Any) -> dict[str, Any]:
    """Map Prisma provider model to response dict."""
    return {
        "id": provider.id,
        "name": provider.name,
        "display_name": provider.displayName,
        "description": provider.description,
        "default_credential_provider": provider.defaultCredentialProvider,
        "default_credential_id": provider.defaultCredentialId,
        "default_credential_type": provider.defaultCredentialType,
        "metadata": dict(provider.metadata or {}),
        "created_at": provider.createdAt.isoformat() if provider.createdAt else None,
        "updated_at": provider.updatedAt.isoformat() if provider.updatedAt else None,
    }


def _map_model_response(model: Any) -> dict[str, Any]:
    """Map Prisma model to response dict."""
    return {
        "id": model.id,
        "slug": model.slug,
        "display_name": model.displayName,
        "description": model.description,
        "provider_id": model.providerId,
        "creator_id": model.creatorId,
        "context_window": model.contextWindow,
        "max_output_tokens": model.maxOutputTokens,
        "price_tier": model.priceTier,
        "is_enabled": model.isEnabled,
        "is_recommended": model.isRecommended,
        "supports_tools": model.supportsTools,
        "supports_json_output": model.supportsJsonOutput,
        "supports_reasoning": model.supportsReasoning,
        "supports_parallel_tool_calls": model.supportsParallelToolCalls,
        "capabilities": dict(model.capabilities or {}),
        "metadata": dict(model.metadata or {}),
        "created_at": model.createdAt.isoformat() if model.createdAt else None,
        "updated_at": model.updatedAt.isoformat() if model.updatedAt else None,
    }


def _map_creator_response(creator: Any) -> dict[str, Any]:
    """Map Prisma creator model to response dict."""
    return {
        "id": creator.id,
        "name": creator.name,
        "display_name": creator.displayName,
        "description": creator.description,
        "website_url": creator.websiteUrl,
        "logo_url": creator.logoUrl,
        "metadata": dict(creator.metadata or {}),
        "created_at": creator.createdAt.isoformat() if creator.createdAt else None,
        "updated_at": creator.updatedAt.isoformat() if creator.updatedAt else None,
    }


@router.post(
    "/llm/models",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_model(
    request: CreateLlmModelRequest,
) -> dict[str, Any]:
    """Create a new LLM model.

    Requires admin authentication.
    """
    try:
        import prisma.models as pm

        # Resolve provider name to ID
        provider = await pm.LlmProvider.prisma().find_unique(
            where={"name": request.provider_id}
        )
        if not provider:
            # Try as UUID fallback
            provider = await pm.LlmProvider.prisma().find_unique(
                where={"id": request.provider_id}
            )
        if not provider:
            raise HTTPException(
                status_code=404,
                detail=f"Provider '{request.provider_id}' not found",
            )

        model = await db_write.create_model(
            slug=request.slug,
            display_name=request.display_name,
            provider_id=provider.id,
            context_window=request.context_window,
            price_tier=request.price_tier,
            description=request.description,
            creator_id=request.creator_id,
            max_output_tokens=request.max_output_tokens,
            is_enabled=request.is_enabled,
            is_recommended=request.is_recommended,
            supports_tools=request.supports_tools,
            supports_json_output=request.supports_json_output,
            supports_reasoning=request.supports_reasoning,
            supports_parallel_tool_calls=request.supports_parallel_tool_calls,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )
        # Create cost entries if provided in the request
        if request.costs:
            for cost_input in request.costs:
                await pm.LlmModelCost.prisma().create(
                    data={
                        "unit": cost_input.get("unit", "RUN"),
                        "creditCost": int(cost_input.get("credit_cost", 1)),
                        "credentialProvider": provider.name,
                        "metadata": prisma.Json(cost_input.get("metadata", {})),
                        "Model": {"connect": {"id": model.id}},
                    }
                )

        await db_write.refresh_runtime_caches()
        logger.info(f"Created model '{request.slug}' (id: {model.id})")

        # Re-fetch with costs included
        model = await pm.LlmModel.prisma().find_unique(
            where={"id": model.id},
            include={"Costs": True, "Creator": True},
        )
        return _map_model_response(model)
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Model creation validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to create model: {e}")
        raise HTTPException(status_code=500, detail="Failed to create model")


@router.patch(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_model(
    slug: str,
    request: UpdateLlmModelRequest,
) -> dict[str, Any]:
    """Update an existing LLM model.

    Requires admin authentication.
    """
    try:
        # Find model by slug first to get ID
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        model = await db_write.update_model(
            model_id=existing.id,
            display_name=request.display_name,
            description=request.description,
            creator_id=request.creator_id,
            context_window=request.context_window,
            max_output_tokens=request.max_output_tokens,
            price_tier=request.price_tier,
            is_enabled=request.is_enabled,
            is_recommended=request.is_recommended,
            supports_tools=request.supports_tools,
            supports_json_output=request.supports_json_output,
            supports_reasoning=request.supports_reasoning,
            supports_parallel_tool_calls=request.supports_parallel_tool_calls,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Updated model '{slug}' (id: {model.id})")
        return _map_model_response(model)
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Model update validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update model: {e}")
        raise HTTPException(status_code=500, detail="Failed to update model")


@router.delete(
    "/llm/models/{slug:path}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_model(
    slug: str,
    replacement_model_slug: str | None = None,
) -> dict[str, Any]:
    """Delete an LLM model with optional migration.

    If workflows are using this model and no replacement_model_slug is given,
    returns 400 with the node count. Provide replacement_model_slug to migrate
    affected nodes before deletion.
    """
    try:
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        result = await db_write.delete_model(
            model_id=existing.id,
            replacement_model_slug=replacement_model_slug,
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Deleted model '{slug}' (migrated {result['nodes_migrated']} nodes)"
        )
        return result
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Model deletion validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete model: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete model")


@router.get(
    "/llm/models/{slug:path}/usage",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Get usage count for a model, i.e. how many workflow nodes reference it."""
    try:
        return await db_write.get_model_usage(slug)
    except Exception as e:
        logger.exception(f"Failed to get model usage: {e}")
        raise HTTPException(status_code=500, detail="Failed to get model usage")


@router.post(
    "/llm/models/{slug:path}/toggle",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def toggle_model(
    slug: str,
    request: dict[str, Any],
) -> dict[str, Any]:
    """Toggle a model's enabled status with optional migration when disabling.

    Body params:
        is_enabled: bool
        migrate_to_slug: optional str
        migration_reason: optional str
        custom_credit_cost: optional int
    """
    try:
        import prisma.models

        existing = await prisma.models.LlmModel.prisma().find_unique(
            where={"slug": slug}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Model with slug '{slug}' not found"
            )

        result = await db_write.toggle_model_with_migration(
            model_id=existing.id,
            is_enabled=request.get("is_enabled", True),
            migrate_to_slug=request.get("migrate_to_slug"),
            migration_reason=request.get("migration_reason"),
            custom_credit_cost=request.get("custom_credit_cost"),
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Toggled model '{slug}' enabled={request.get('is_enabled')} "
            f"(migrated {result['nodes_migrated']} nodes)"
        )
        return result
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Model toggle failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to toggle model: {e}")
        raise HTTPException(status_code=500, detail="Failed to toggle model")


@router.get(
    "/llm/migrations",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_migrations(
    include_reverted: bool = False,
) -> dict[str, Any]:
    """List model migrations."""
    try:
        migrations = await db_write.list_migrations(include_reverted=include_reverted)
        return {"migrations": migrations}
    except Exception as e:
        logger.exception(f"Failed to list migrations: {e}")
        raise HTTPException(status_code=500, detail="Failed to list migrations")


@router.post(
    "/llm/migrations/{migration_id}/revert",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Revert a model migration, restoring affected nodes."""
    try:
        result = await db_write.revert_migration(
            migration_id=migration_id,
            re_enable_source_model=re_enable_source_model,
        )
        await db_write.refresh_runtime_caches()
        logger.info(
            f"Reverted migration {migration_id}: "
            f"{result['nodes_reverted']} nodes restored"
        )
        return result
    except ValueError as e:
        logger.warning(f"Migration revert failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to revert migration: {e}")
        raise HTTPException(status_code=500, detail="Failed to revert migration")


@router.post(
    "/llm/providers",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_provider(
    request: CreateLlmProviderRequest,
) -> dict[str, Any]:
    """Create a new LLM provider.

    Requires admin authentication.
    """
    try:
        provider = await db_write.create_provider(
            name=request.name,
            display_name=request.display_name,
            description=request.description,
            default_credential_provider=request.default_credential_provider,
            default_credential_id=request.default_credential_id,
            default_credential_type=request.default_credential_type,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Created provider '{request.name}' (id: {provider.id})")
        return _map_provider_response(provider)
    except ValueError as e:
        logger.warning(f"Provider creation validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to create provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to create provider")


@router.patch(
    "/llm/providers/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_provider(
    name: str,
    request: UpdateLlmProviderRequest,
) -> dict[str, Any]:
    """Update an existing LLM provider.

    Requires admin authentication.
    """
    try:
        # Find provider by name first to get ID
        import prisma.models

        existing = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        provider = await db_write.update_provider(
            provider_id=existing.id,
            display_name=request.display_name,
            description=request.description,
            default_credential_provider=request.default_credential_provider,
            default_credential_id=request.default_credential_id,
            default_credential_type=request.default_credential_type,
            metadata=request.metadata,
        )
        await db_write.refresh_runtime_caches()
        logger.info(f"Updated provider '{name}' (id: {provider.id})")
        return _map_provider_response(provider)
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Provider update validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to update provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to update provider")


@router.delete(
    "/llm/providers/{name}",
    status_code=status.HTTP_204_NO_CONTENT,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def delete_provider(
    name: str,
) -> None:
    """Delete an LLM provider.

    Requires admin authentication.
    A provider can only be deleted if it has no associated models.
    """
    try:
        # Find provider by name first to get ID
        import prisma.models

        existing = await prisma.models.LlmProvider.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Provider with name '{name}' not found"
            )

        await db_write.delete_provider(provider_id=existing.id)
        await db_write.refresh_runtime_caches()
        logger.info(f"Deleted provider '{name}' (id: {existing.id})")
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Provider deletion validation failed: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Failed to delete provider: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete provider")


@router.get(
    "/llm/admin/providers",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_providers() -> dict[str, Any]:
    """List all LLM providers from the database.

    Unlike the public endpoint, this returns ALL providers, including
    those with no models. Requires admin authentication.
    """
    try:
        import prisma.models

        providers = await prisma.models.LlmProvider.prisma().find_many(
            order={"name": "asc"},
            include={"Models": True},
        )
        return {
            "providers": [
                {
                    **_map_provider_response(p),
                    "model_count": len(p.Models) if p.Models else 0,
                }
                for p in providers
            ]
        }
    except Exception as e:
        logger.exception(f"Failed to list providers: {e}")
        raise HTTPException(status_code=500, detail="Failed to list providers")


@router.get(
    "/llm/admin/models",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def admin_list_models(
    page: int = 1,
    page_size: int = 100,
    enabled_only: bool = False,
) -> dict[str, Any]:
    """List all LLM models from the database.

    Unlike the public endpoint, this returns full model data including
    costs and creator info. Requires admin authentication.
    """
    try:
        import prisma.models

        where = {"isEnabled": True} if enabled_only else {}
        models = await prisma.models.LlmModel.prisma().find_many(
            where=where,
            skip=(page - 1) * page_size,
            take=page_size,
            order={"displayName": "asc"},
            include={"Costs": True, "Creator": True},
        )
        return {
            "models": [
                {
                    **_map_model_response(m),
                    "creator": _map_creator_response(m.Creator) if m.Creator else None,
                    "costs": [
                        {
                            "unit": c.unit,
                            "credit_cost": float(c.creditCost),
                            "credential_provider": c.credentialProvider,
                            "credential_type": c.credentialType,
                            "metadata": dict(c.metadata or {}),
                        }
                        for c in (m.Costs or [])
                    ],
                }
                for m in models
            ]
        }
    except Exception as e:
        logger.exception(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models")


@router.get(
    "/llm/creators",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def list_creators() -> dict[str, Any]:
    """List all LLM model creators.

    Requires admin authentication.
    """
    try:
        import prisma.models

        creators = await prisma.models.LlmModelCreator.prisma().find_many(
            order={"name": "asc"}
        )
        logger.info(f"Retrieved {len(creators)} creators")
        return {"creators": [_map_creator_response(c) for c in creators]}
    except Exception as e:
        logger.exception(f"Failed to list creators: {e}")
        raise HTTPException(status_code=500, detail="Failed to list creators")


@router.post(
    "/llm/creators",
    status_code=status.HTTP_201_CREATED,
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def create_creator(
    request: dict[str, Any],
) -> dict[str, Any]:
    """Create a new LLM model creator."""
    try:
        import prisma.models

        creator = await prisma.models.LlmModelCreator.prisma().create(
            data={
                "name": request["name"],
                "displayName": request["display_name"],
                "description": request.get("description"),
                "websiteUrl": request.get("website_url"),
                "logoUrl": request.get("logo_url"),
                "metadata": prisma.Json(request.get("metadata", {})),
            }
        )
        logger.info(f"Created creator '{creator.name}' (id: {creator.id})")
        return _map_creator_response(creator)
    except Exception as e:
        logger.exception(f"Failed to create creator: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.patch(
    "/llm/creators/{name}",
    dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
)
async def update_creator(
    name: str,
    request: dict[str, Any],
) -> dict[str, Any]:
    """Update an existing LLM model creator."""
    try:
        import prisma.models

        existing = await prisma.models.LlmModelCreator.prisma().find_unique(
            where={"name": name}
        )
        if not existing:
            raise HTTPException(
                status_code=404, detail=f"Creator '{name}' not found"
            )

        data: dict[str, Any] = {}
        if "display_name" in request:
            data["displayName"] = request["display_name"]
        if "description" in request:
            data["description"] = request["description"]
        if "website_url" in request:
            data["websiteUrl"] = request["website_url"]
        if "logo_url" in request:
            data["logoUrl"] = request["logo_url"]

        creator = await prisma.models.LlmModelCreator.prisma().update(
            where={"id": existing.id},
            data=data,
        )
        logger.info(f"Updated creator '{name}' (id: {creator.id})")
        return _map_creator_response(creator)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to update creator: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/llm/creators/{name}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(autogpt_libs.auth.requires_admin_user)],
|
||||
)
|
||||
async def delete_creator(
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete an LLM model creator."""
|
||||
try:
|
||||
import prisma.models
|
||||
|
||||
existing = await prisma.models.LlmModelCreator.prisma().find_unique(
|
||||
where={"name": name},
|
||||
include={"Models": True},
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Creator '{name}' not found"
|
||||
)
|
||||
|
||||
if existing.Models and len(existing.Models) > 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot delete creator '{name}' — it has {len(existing.Models)} associated models",
|
||||
)
|
||||
|
||||
await prisma.models.LlmModelCreator.prisma().delete(
|
||||
where={"id": existing.id}
|
||||
)
|
||||
logger.info(f"Deleted creator '{name}' (id: {existing.id})")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete creator: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
588
autogpt_platform/backend/backend/server/v2/llm/db_write.py
Normal file
@@ -0,0 +1,588 @@
"""Database write operations for LLM registry admin API."""

import json
import logging
from datetime import datetime, timezone
from typing import Any

import prisma
import prisma.models

from backend.data import llm_registry
from backend.data.db import transaction

logger = logging.getLogger(__name__)


def _build_provider_data(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build provider data dict for Prisma operations."""
    return {
        "name": name,
        "displayName": display_name,
        "description": description,
        "defaultCredentialProvider": default_credential_provider,
        "defaultCredentialId": default_credential_id,
        "defaultCredentialType": default_credential_type,
        "metadata": prisma.Json(metadata or {}),
    }


def _build_model_data(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build model data dict for Prisma operations."""
    data: dict[str, Any] = {
        "slug": slug,
        "displayName": display_name,
        "description": description,
        "Provider": {"connect": {"id": provider_id}},
        "contextWindow": context_window,
        "maxOutputTokens": max_output_tokens,
        "priceTier": price_tier,
        "isEnabled": is_enabled,
        "isRecommended": is_recommended,
        "supportsTools": supports_tools,
        "supportsJsonOutput": supports_json_output,
        "supportsReasoning": supports_reasoning,
        "supportsParallelToolCalls": supports_parallel_tool_calls,
        "capabilities": prisma.Json(capabilities or {}),
        "metadata": prisma.Json(metadata or {}),
    }
    if creator_id:
        data["Creator"] = {"connect": {"id": creator_id}}
    return data
async def create_provider(
    name: str,
    display_name: str,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Create a new LLM provider."""
    data = _build_provider_data(
        name=name,
        display_name=display_name,
        description=description,
        default_credential_provider=default_credential_provider,
        default_credential_id=default_credential_id,
        default_credential_type=default_credential_type,
        metadata=metadata,
    )
    provider = await prisma.models.LlmProvider.prisma().create(
        data=data,
        include={"Models": True},
    )
    if not provider:
        raise ValueError("Failed to create provider")
    return provider


async def update_provider(
    provider_id: str,
    display_name: str | None = None,
    description: str | None = None,
    default_credential_provider: str | None = None,
    default_credential_id: str | None = None,
    default_credential_type: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmProvider:
    """Update an existing LLM provider."""
    # Fetch the existing provider to confirm it exists
    provider = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id}
    )
    if not provider:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Build update data (only include fields that are provided)
    data: dict[str, Any] = {}
    if display_name is not None:
        data["displayName"] = display_name
    if description is not None:
        data["description"] = description
    if default_credential_provider is not None:
        data["defaultCredentialProvider"] = default_credential_provider
    if default_credential_id is not None:
        data["defaultCredentialId"] = default_credential_id
    if default_credential_type is not None:
        data["defaultCredentialType"] = default_credential_type
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)

    updated = await prisma.models.LlmProvider.prisma().update(
        where={"id": provider_id},
        data=data,
        include={"Models": True},
    )
    if not updated:
        raise ValueError("Failed to update provider")
    return updated


async def delete_provider(provider_id: str) -> bool:
    """Delete an LLM provider.

    A provider can only be deleted if it has no associated models.
    """
    # Check if the provider exists
    provider = await prisma.models.LlmProvider.prisma().find_unique(
        where={"id": provider_id},
        include={"Models": True},
    )
    if not provider:
        raise ValueError(f"Provider with id '{provider_id}' not found")

    # Check if the provider has any models
    model_count = len(provider.Models) if provider.Models else 0
    if model_count > 0:
        raise ValueError(
            f"Cannot delete provider '{provider.displayName}' because it has "
            f"{model_count} model(s). Delete all models first."
        )

    await prisma.models.LlmProvider.prisma().delete(where={"id": provider_id})
    return True
async def create_model(
    slug: str,
    display_name: str,
    provider_id: str,
    context_window: int,
    price_tier: int,
    description: str | None = None,
    creator_id: str | None = None,
    max_output_tokens: int | None = None,
    is_enabled: bool = True,
    is_recommended: bool = False,
    supports_tools: bool = False,
    supports_json_output: bool = False,
    supports_reasoning: bool = False,
    supports_parallel_tool_calls: bool = False,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Create a new LLM model."""
    data = _build_model_data(
        slug=slug,
        display_name=display_name,
        provider_id=provider_id,
        context_window=context_window,
        price_tier=price_tier,
        description=description,
        creator_id=creator_id,
        max_output_tokens=max_output_tokens,
        is_enabled=is_enabled,
        is_recommended=is_recommended,
        supports_tools=supports_tools,
        supports_json_output=supports_json_output,
        supports_reasoning=supports_reasoning,
        supports_parallel_tool_calls=supports_parallel_tool_calls,
        capabilities=capabilities,
        metadata=metadata,
    )
    model = await prisma.models.LlmModel.prisma().create(
        data=data,
        include={"Costs": True, "Creator": True, "Provider": True},
    )
    if not model:
        raise ValueError("Failed to create model")
    return model


async def update_model(
    model_id: str,
    display_name: str | None = None,
    description: str | None = None,
    creator_id: str | None = None,
    context_window: int | None = None,
    max_output_tokens: int | None = None,
    price_tier: int | None = None,
    is_enabled: bool | None = None,
    is_recommended: bool | None = None,
    supports_tools: bool | None = None,
    supports_json_output: bool | None = None,
    supports_reasoning: bool | None = None,
    supports_parallel_tool_calls: bool | None = None,
    capabilities: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> prisma.models.LlmModel:
    """Update an existing LLM model.

    When is_recommended=True, clears the flag on all other models first so
    only one model can be recommended at a time.
    """
    # Build update data (only include fields that are provided)
    data: dict[str, Any] = {}
    if display_name is not None:
        data["displayName"] = display_name
    if description is not None:
        data["description"] = description
    if context_window is not None:
        data["contextWindow"] = context_window
    if max_output_tokens is not None:
        data["maxOutputTokens"] = max_output_tokens
    if price_tier is not None:
        data["priceTier"] = price_tier
    if is_enabled is not None:
        data["isEnabled"] = is_enabled
    if is_recommended is not None:
        data["isRecommended"] = is_recommended
    if supports_tools is not None:
        data["supportsTools"] = supports_tools
    if supports_json_output is not None:
        data["supportsJsonOutput"] = supports_json_output
    if supports_reasoning is not None:
        data["supportsReasoning"] = supports_reasoning
    if supports_parallel_tool_calls is not None:
        data["supportsParallelToolCalls"] = supports_parallel_tool_calls
    if capabilities is not None:
        data["capabilities"] = prisma.Json(capabilities)
    if metadata is not None:
        data["metadata"] = prisma.Json(metadata)
    if creator_id is not None:
        data["creatorId"] = creator_id if creator_id else None

    async with transaction() as tx:
        # Enforce single recommended model: unset all others first.
        if is_recommended is True:
            await tx.llmmodel.update_many(
                where={"id": {"not": model_id}},
                data={"isRecommended": False},
            )

        model = await tx.llmmodel.update(
            where={"id": model_id},
            data=data,
            include={"Costs": True, "Creator": True, "Provider": True},
        )

    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")
    return model
async def get_model_usage(slug: str) -> dict[str, Any]:
    """Get usage count for a model: how many AgentNodes reference it."""
    import prisma as prisma_module

    count_result = await prisma_module.get_client().query_raw(
        """
        SELECT COUNT(*) as count
        FROM "AgentNode"
        WHERE "constantInput"::jsonb->>'model' = $1
        """,
        slug,
    )
    node_count = int(count_result[0]["count"]) if count_result else 0
    return {"model_slug": slug, "node_count": node_count}


async def toggle_model_with_migration(
    model_id: str,
    is_enabled: bool,
    migrate_to_slug: str | None = None,
    migration_reason: str | None = None,
    custom_credit_cost: int | None = None,
) -> dict[str, Any]:
    """Toggle a model's enabled status, optionally migrating workflows when disabling."""
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    nodes_migrated = 0
    migration_id: str | None = None

    if not is_enabled and migrate_to_slug:
        async with transaction() as tx:
            replacement = await tx.llmmodel.find_unique(
                where={"slug": migrate_to_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{migrate_to_slug}' is disabled. "
                    f"Please enable it before using it as a replacement."
                )

            node_ids_result = await tx.query_raw(
                """
                SELECT id
                FROM "AgentNode"
                WHERE "constantInput"::jsonb->>'model' = $1
                FOR UPDATE
                """,
                model.slug,
            )
            migrated_node_ids = (
                [row["id"] for row in node_ids_result] if node_ids_result else []
            )
            nodes_migrated = len(migrated_node_ids)

            if nodes_migrated > 0:
                node_ids_json = json.dumps(migrated_node_ids)
                await tx.execute_raw(
                    """
                    UPDATE "AgentNode"
                    SET "constantInput" = JSONB_SET(
                        "constantInput"::jsonb,
                        '{model}',
                        to_jsonb($1::text)
                    )
                    WHERE id::text IN (
                        SELECT jsonb_array_elements_text($2::jsonb)
                    )
                    """,
                    migrate_to_slug,
                    node_ids_json,
                )

            await tx.llmmodel.update(
                where={"id": model_id},
                data={"isEnabled": is_enabled},
            )

            if nodes_migrated > 0:
                migration_record = await tx.llmmodelmigration.create(
                    data={
                        "sourceModelSlug": model.slug,
                        "targetModelSlug": migrate_to_slug,
                        "reason": migration_reason,
                        "migratedNodeIds": json.dumps(migrated_node_ids),
                        "nodeCount": nodes_migrated,
                        "customCreditCost": custom_credit_cost,
                    }
                )
                migration_id = migration_record.id
    else:
        await prisma.models.LlmModel.prisma().update(
            where={"id": model_id},
            data={"isEnabled": is_enabled},
        )

    return {
        "nodes_migrated": nodes_migrated,
        "migrated_to_slug": migrate_to_slug if nodes_migrated > 0 else None,
        "migration_id": migration_id,
    }
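The transactional bookkeeping above (collect the affected node ids, rewrite their `model` field, remember which nodes were touched) can be sketched without a database. The node dicts below are hypothetical stand-ins for `AgentNode` rows, and the slugs are made up:

```python
# Pure-Python sketch of the migrate-on-disable bookkeeping in
# toggle_model_with_migration; dicts stand in for AgentNode rows.
nodes = [
    {"id": "n1", "constantInput": {"model": "gpt-old"}},
    {"id": "n2", "constantInput": {"model": "gpt-new"}},
    {"id": "n3", "constantInput": {"model": "gpt-old"}},
]

source_slug, target_slug = "gpt-old", "gpt-new"

# 1. Collect the affected node ids (the SELECT ... FOR UPDATE step).
migrated_node_ids = [
    n["id"] for n in nodes if n["constantInput"].get("model") == source_slug
]

# 2. Rewrite their model field (the JSONB_SET update), leaving others alone.
for n in nodes:
    if n["id"] in migrated_node_ids:
        n["constantInput"]["model"] = target_slug

print(migrated_node_ids)  # ['n1', 'n3']
print(all(n["constantInput"]["model"] == target_slug for n in nodes))  # True
```

Recording `migrated_node_ids` is what makes `revert_migration` possible later: the revert only touches exactly those nodes, and only if they still point at the target slug.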
async def delete_model(
    model_id: str, replacement_model_slug: str | None = None
) -> dict[str, Any]:
    """Delete an LLM model, optionally migrating affected AgentNodes first.

    If workflows are using this model and no replacement is given, raises ValueError.
    If a replacement is given, atomically migrates all affected nodes, then deletes.
    """
    model = await prisma.models.LlmModel.prisma().find_unique(
        where={"id": model_id}, include={"Costs": True}
    )
    if not model:
        raise ValueError(f"Model with id '{model_id}' not found")

    deleted_slug = model.slug
    deleted_display_name = model.displayName

    async with transaction() as tx:
        count_result = await tx.query_raw(
            """
            SELECT COUNT(*) as count
            FROM "AgentNode"
            WHERE "constantInput"::jsonb->>'model' = $1
            """,
            deleted_slug,
        )
        nodes_to_migrate = int(count_result[0]["count"]) if count_result else 0

        if nodes_to_migrate > 0:
            if not replacement_model_slug:
                raise ValueError(
                    f"Cannot delete model '{deleted_slug}': {nodes_to_migrate} workflow node(s) "
                    f"are using it. Please provide a replacement_model_slug to migrate them."
                )
            replacement = await tx.llmmodel.find_unique(
                where={"slug": replacement_model_slug}
            )
            if not replacement:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' not found"
                )
            if not replacement.isEnabled:
                raise ValueError(
                    f"Replacement model '{replacement_model_slug}' is disabled."
                )

            await tx.execute_raw(
                """
                UPDATE "AgentNode"
                SET "constantInput" = JSONB_SET(
                    "constantInput"::jsonb,
                    '{model}',
                    to_jsonb($1::text)
                )
                WHERE "constantInput"::jsonb->>'model' = $2
                """,
                replacement_model_slug,
                deleted_slug,
            )

        await tx.llmmodel.delete(where={"id": model_id})

    return {
        "deleted_model_slug": deleted_slug,
        "deleted_model_display_name": deleted_display_name,
        "replacement_model_slug": replacement_model_slug,
        "nodes_migrated": nodes_to_migrate,
    }


async def list_migrations(
    include_reverted: bool = False,
) -> list[dict[str, Any]]:
    """List model migrations."""
    where: Any = None if include_reverted else {"isReverted": False}
    records = await prisma.models.LlmModelMigration.prisma().find_many(
        where=where,
        order={"createdAt": "desc"},
    )
    return [
        {
            "id": r.id,
            "source_model_slug": r.sourceModelSlug,
            "target_model_slug": r.targetModelSlug,
            "reason": r.reason,
            "node_count": r.nodeCount,
            "custom_credit_cost": r.customCreditCost,
            "is_reverted": r.isReverted,
            "reverted_at": r.revertedAt.isoformat() if r.revertedAt else None,
            "created_at": r.createdAt.isoformat(),
        }
        for r in records
    ]
async def revert_migration(
    migration_id: str,
    re_enable_source_model: bool = True,
) -> dict[str, Any]:
    """Revert a model migration, restoring affected nodes to their original model."""
    migration = await prisma.models.LlmModelMigration.prisma().find_unique(
        where={"id": migration_id}
    )
    if not migration:
        raise ValueError(f"Migration with id '{migration_id}' not found")

    if migration.isReverted:
        raise ValueError(
            f"Migration '{migration_id}' has already been reverted"
        )

    source_model = await prisma.models.LlmModel.prisma().find_unique(
        where={"slug": migration.sourceModelSlug}
    )
    if not source_model:
        raise ValueError(
            f"Source model '{migration.sourceModelSlug}' no longer exists."
        )

    migrated_node_ids: list[str] = (
        migration.migratedNodeIds
        if isinstance(migration.migratedNodeIds, list)
        else json.loads(migration.migratedNodeIds)  # type: ignore
    )
    if not migrated_node_ids:
        raise ValueError("No nodes to revert in this migration")

    source_model_re_enabled = False

    async with transaction() as tx:
        if not source_model.isEnabled and re_enable_source_model:
            await tx.llmmodel.update(
                where={"id": source_model.id},
                data={"isEnabled": True},
            )
            source_model_re_enabled = True

        node_ids_json = json.dumps(migrated_node_ids)
        result = await tx.execute_raw(
            """
            UPDATE "AgentNode"
            SET "constantInput" = JSONB_SET(
                "constantInput"::jsonb,
                '{model}',
                to_jsonb($1::text)
            )
            WHERE id::text IN (
                SELECT jsonb_array_elements_text($2::jsonb)
            )
            AND "constantInput"::jsonb->>'model' = $3
            """,
            migration.sourceModelSlug,
            node_ids_json,
            migration.targetModelSlug,
        )
        nodes_reverted = result if isinstance(result, int) else 0

        await tx.llmmodelmigration.update(
            where={"id": migration_id},
            data={
                "isReverted": True,
                "revertedAt": datetime.now(timezone.utc),
            },
        )

    return {
        "migration_id": migration_id,
        "source_model_slug": migration.sourceModelSlug,
        "target_model_slug": migration.targetModelSlug,
        "nodes_reverted": nodes_reverted,
        "nodes_already_changed": len(migrated_node_ids) - nodes_reverted,
        "source_model_re_enabled": source_model_re_enabled,
    }
async def refresh_runtime_caches() -> None:
    """Invalidate the shared Redis cache, refresh this process, notify other workers."""
    from backend.data.llm_registry.notifications import (
        publish_registry_refresh_notification,
    )

    # Invalidate Redis so the next fetch hits the DB.
    llm_registry.clear_registry_cache()
    # Refresh this process (also repopulates Redis via @cached(shared_cache=True)).
    await llm_registry.refresh_llm_registry()
    # Tell other workers to reload their in-process cache from the fresh Redis data.
    await publish_registry_refresh_notification()
68
autogpt_platform/backend/backend/server/v2/llm/model.py
Normal file
@@ -0,0 +1,68 @@
"""Pydantic models for LLM registry public API."""

from __future__ import annotations

from typing import Any

import pydantic


class LlmModelCost(pydantic.BaseModel):
    """Cost configuration for an LLM model."""

    unit: str  # "RUN" or "TOKENS"
    credit_cost: int = pydantic.Field(ge=0)
    credential_provider: str
    credential_id: str | None = None
    credential_type: str | None = None
    currency: str | None = None
    metadata: dict[str, Any] = pydantic.Field(default_factory=dict)


class LlmModelCreator(pydantic.BaseModel):
    """Represents the organization that created/trained the model."""

    id: str
    name: str
    display_name: str
    description: str | None = None
    website_url: str | None = None
    logo_url: str | None = None


class LlmModel(pydantic.BaseModel):
    """Public-facing LLM model information."""

    slug: str
    display_name: str
    description: str | None = None
    provider_name: str
    creator: LlmModelCreator | None = None
    context_window: int
    max_output_tokens: int | None = None
    price_tier: int  # 1=cheapest, 2=medium, 3=expensive
    is_enabled: bool = True
    is_recommended: bool = False
    capabilities: dict[str, Any] = pydantic.Field(default_factory=dict)
    costs: list[LlmModelCost] = pydantic.Field(default_factory=list)


class LlmProvider(pydantic.BaseModel):
    """Provider with its enabled models."""

    name: str
    display_name: str
    models: list[LlmModel] = pydantic.Field(default_factory=list)


class LlmModelsResponse(pydantic.BaseModel):
    """Response for GET /llm/models."""

    models: list[LlmModel]
    total: int


class LlmProvidersResponse(pydantic.BaseModel):
    """Response for GET /llm/providers."""

    providers: list[LlmProvider]
143
autogpt_platform/backend/backend/server/v2/llm/routes.py
Normal file
@@ -0,0 +1,143 @@
"""Public read-only API for LLM registry."""

import autogpt_libs.auth
import fastapi

from backend.data.llm_registry import (
    RegistryModelCreator,
    get_all_models,
    get_enabled_models,
)
from backend.server.v2.llm import model as llm_model

router = fastapi.APIRouter(
    prefix="/llm",
    tags=["llm"],
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
)


def _map_creator(
    creator: RegistryModelCreator | None,
) -> llm_model.LlmModelCreator | None:
    """Convert registry creator to API model."""
    if not creator:
        return None
    return llm_model.LlmModelCreator(
        id=creator.id,
        name=creator.name,
        display_name=creator.display_name,
        description=creator.description,
        website_url=creator.website_url,
        logo_url=creator.logo_url,
    )


@router.get("/models", response_model=llm_model.LlmModelsResponse)
async def list_models(
    enabled_only: bool = fastapi.Query(
        default=True, description="Only return enabled models"
    ),
):
    """
    List all LLM models available to users.

    Returns models from the in-memory registry cache.
    Use enabled_only=true to filter to only enabled models (default).
    """
    # Get models from the in-memory registry
    registry_models = get_enabled_models() if enabled_only else get_all_models()

    # Map to API response models
    models = [
        llm_model.LlmModel(
            slug=model.slug,
            display_name=model.display_name,
            description=model.description,
            provider_name=model.provider_display_name,
            creator=_map_creator(model.creator),
            context_window=model.metadata.context_window,
            max_output_tokens=model.metadata.max_output_tokens,
            price_tier=model.metadata.price_tier,
            is_enabled=model.is_enabled,
            is_recommended=model.is_recommended,
            capabilities=model.capabilities,
            costs=[
                llm_model.LlmModelCost(
                    unit=cost.unit,
                    credit_cost=cost.credit_cost,
                    credential_provider=cost.credential_provider,
                    credential_id=cost.credential_id,
                    credential_type=cost.credential_type,
                    currency=cost.currency,
                    metadata=cost.metadata,
                )
                for cost in model.costs
            ],
        )
        for model in registry_models
    ]

    return llm_model.LlmModelsResponse(models=models, total=len(models))
@router.get("/providers", response_model=llm_model.LlmProvidersResponse)
async def list_providers():
    """
    List all LLM providers with their enabled models.

    Groups enabled models by provider from the in-memory registry.
    """
    # Get all enabled models and group by provider
    registry_models = get_enabled_models()

    # Group models by provider
    provider_map: dict[str, list] = {}
    for model in registry_models:
        provider_key = model.metadata.provider
        if provider_key not in provider_map:
            provider_map[provider_key] = []
        provider_map[provider_key].append(model)

    # Build provider responses
    providers = []
    for provider_key, models in sorted(provider_map.items()):
        # Use the first model's provider display name
        display_name = models[0].provider_display_name if models else provider_key

        providers.append(
            llm_model.LlmProvider(
                name=provider_key,
                display_name=display_name,
                models=[
                    llm_model.LlmModel(
                        slug=model.slug,
                        display_name=model.display_name,
                        description=model.description,
                        provider_name=model.provider_display_name,
                        creator=_map_creator(model.creator),
                        context_window=model.metadata.context_window,
                        max_output_tokens=model.metadata.max_output_tokens,
                        price_tier=model.metadata.price_tier,
                        is_enabled=model.is_enabled,
                        is_recommended=model.is_recommended,
                        capabilities=model.capabilities,
                        costs=[
                            llm_model.LlmModelCost(
                                unit=cost.unit,
                                credit_cost=cost.credit_cost,
                                credential_provider=cost.credential_provider,
                                credential_id=cost.credential_id,
                                credential_type=cost.credential_type,
                                currency=cost.currency,
                                metadata=cost.metadata,
                            )
                            for cost in model.costs
                        ],
                    )
                    for model in sorted(models, key=lambda m: m.display_name)
                ],
            )
        )

    return llm_model.LlmProvidersResponse(providers=providers)
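The group-by-provider step in `list_providers` boils down to a dictionary of lists plus two sorts. A minimal sketch, with made-up `(provider, display_name)` tuples standing in for registry model objects:

```python
# Sketch of the provider-grouping logic in list_providers,
# using plain tuples as stand-ins for registry model entries.
models = [
    ("openai", "GPT-4o"),
    ("anthropic", "Claude Sonnet"),
    ("openai", "GPT-4o mini"),
]

provider_map: dict[str, list[str]] = {}
for provider, display_name in models:
    provider_map.setdefault(provider, []).append(display_name)

# Providers sorted by key; models sorted by display name within each.
grouped = {
    provider: sorted(names) for provider, names in sorted(provider_map.items())
}
print(grouped)
# {'anthropic': ['Claude Sonnet'], 'openai': ['GPT-4o', 'GPT-4o mini']}
```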
@@ -704,8 +704,19 @@ def get_service_client(
        return kwargs

    def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
        """Validate and coerce the RPC result to the expected return type.

        Falls back to the raw result with a warning if validation fails.
        """
        if expected_return:
            try:
                return expected_return.validate_python(result)
            except Exception as e:
                logger.warning(
                    "RPC return type validation failed, using raw result: %s",
                    type(e).__name__,
                )
                return result
        return result

    def __getattr__(self, name: str) -> Callable[..., Any]:
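The validate-then-fall-back pattern in `_get_return` can be exercised directly with pydantic's `TypeAdapter` (a standalone sketch assuming pydantic v2, not the service client itself):

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(list[int])

# Coercible result: validated and coerced to the expected return type.
ok = adapter.validate_python(["1", 2])
print(ok)  # [1, 2]

# Non-coercible result: validate_python raises, so the caller falls
# back to the raw value, mirroring _get_return's except branch.
raw = {"not": "a list"}
try:
    value = adapter.validate_python(raw)
except Exception:
    value = raw
print(value is raw)  # True
```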
@@ -302,7 +302,14 @@ def _value_satisfies_type(value: Any, target: Any) -> bool:

    # Simple type (e.g. str, int)
    if isinstance(target, type):
        try:
            return isinstance(value, target)
        except TypeError:
            # TypedDict and some typing constructs don't support isinstance checks.
            # For TypedDict, check if value is a dict with the required keys.
            if isinstance(value, dict) and hasattr(target, "__required_keys__"):
                return all(k in value for k in target.__required_keys__)
            return False

    return False
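The `TypeError` branch above exists because `TypedDict` subclasses are classes (so they pass `isinstance(target, type)`) yet reject `isinstance` checks at runtime. A quick demonstration, with `Point` as an illustrative stand-in:

```python
from typing import TypedDict


class Point(TypedDict):
    x: int
    y: int


# isinstance() against a TypedDict raises TypeError at runtime...
try:
    isinstance({"x": 1, "y": 2}, Point)
    raised = False
except TypeError:
    raised = True
assert raised

# ...so the hunk falls back to checking required keys on plain dicts.
value = {"x": 1, "y": 2}
assert all(k in value for k in Point.__required_keys__)
assert not all(k in {"x": 1} for k in Point.__required_keys__)  # missing "y"
```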
@@ -0,0 +1,148 @@
-- CreateEnum
CREATE TYPE "LlmCostUnit" AS ENUM ('RUN', 'TOKENS');

-- CreateTable
CREATE TABLE "LlmProvider" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "defaultCredentialProvider" TEXT,
    "defaultCredentialId" TEXT,
    "defaultCredentialType" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmProvider_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "LlmModelCreator" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "name" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "websiteUrl" TEXT,
    "logoUrl" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModelCreator_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "LlmModel" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "slug" TEXT NOT NULL,
    "displayName" TEXT NOT NULL,
    "description" TEXT,
    "providerId" TEXT NOT NULL,
    "creatorId" TEXT,
    "contextWindow" INTEGER NOT NULL,
    "maxOutputTokens" INTEGER,
    "priceTier" INTEGER NOT NULL DEFAULT 1,
    "isEnabled" BOOLEAN NOT NULL DEFAULT true,
    "isRecommended" BOOLEAN NOT NULL DEFAULT false,
    "supportsTools" BOOLEAN NOT NULL DEFAULT false,
    "supportsJsonOutput" BOOLEAN NOT NULL DEFAULT false,
    "supportsReasoning" BOOLEAN NOT NULL DEFAULT false,
    "supportsParallelToolCalls" BOOLEAN NOT NULL DEFAULT false,
    "capabilities" JSONB NOT NULL DEFAULT '{}',
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "LlmModel_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "LlmModelCost" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "unit" "LlmCostUnit" NOT NULL DEFAULT 'RUN',
    "creditCost" INTEGER NOT NULL,
    "credentialProvider" TEXT NOT NULL,
    "credentialId" TEXT,
    "credentialType" TEXT,
    "currency" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',
    "llmModelId" TEXT NOT NULL,

    CONSTRAINT "LlmModelCost_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "LlmModelMigration" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "sourceModelSlug" TEXT NOT NULL,
    "targetModelSlug" TEXT NOT NULL,
    "reason" TEXT,
    "migratedNodeIds" JSONB NOT NULL DEFAULT '[]',
    "nodeCount" INTEGER NOT NULL,
    "customCreditCost" INTEGER,
    "isReverted" BOOLEAN NOT NULL DEFAULT false,
    "revertedAt" TIMESTAMP(3),

    CONSTRAINT "LlmModelMigration_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "LlmProvider_name_key" ON "LlmProvider"("name");

-- CreateIndex
CREATE UNIQUE INDEX "LlmModelCreator_name_key" ON "LlmModelCreator"("name");

-- CreateIndex
CREATE UNIQUE INDEX "LlmModel_slug_key" ON "LlmModel"("slug");

-- CreateIndex
CREATE INDEX "LlmModel_providerId_isEnabled_idx" ON "LlmModel"("providerId", "isEnabled");

-- CreateIndex
CREATE INDEX "LlmModel_creatorId_idx" ON "LlmModel"("creatorId");

-- CreateIndex (partial unique for default costs - no specific credential)
CREATE UNIQUE INDEX "LlmModelCost_default_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL;

-- CreateIndex (partial unique for credential-specific costs)
CREATE UNIQUE INDEX "LlmModelCost_credential_cost_key" ON "LlmModelCost"("llmModelId", "credentialProvider", "credentialId", "unit") WHERE "credentialId" IS NOT NULL;

-- CreateIndex
CREATE INDEX "LlmModelMigration_targetModelSlug_idx" ON "LlmModelMigration"("targetModelSlug");

-- CreateIndex
CREATE INDEX "LlmModelMigration_sourceModelSlug_isReverted_idx" ON "LlmModelMigration"("sourceModelSlug", "isReverted");

-- CreateIndex (partial unique to prevent multiple active migrations per source)
CREATE UNIQUE INDEX "LlmModelMigration_active_source_key" ON "LlmModelMigration"("sourceModelSlug") WHERE "isReverted" = false;

-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_providerId_fkey" FOREIGN KEY ("providerId") REFERENCES "LlmProvider"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModel" ADD CONSTRAINT "LlmModel_creatorId_fkey" FOREIGN KEY ("creatorId") REFERENCES "LlmModelCreator"("id") ON DELETE SET NULL ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelCost" ADD CONSTRAINT "LlmModelCost_llmModelId_fkey" FOREIGN KEY ("llmModelId") REFERENCES "LlmModel"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_sourceModelSlug_fkey" FOREIGN KEY ("sourceModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "LlmModelMigration" ADD CONSTRAINT "LlmModelMigration_targetModelSlug_fkey" FOREIGN KEY ("targetModelSlug") REFERENCES "LlmModel"("slug") ON DELETE RESTRICT ON UPDATE CASCADE;

-- AddCheckConstraints (enforce data integrity)
ALTER TABLE "LlmModel"
    ADD CONSTRAINT "LlmModel_priceTier_check" CHECK ("priceTier" BETWEEN 1 AND 3);

ALTER TABLE "LlmModelCost"
    ADD CONSTRAINT "LlmModelCost_creditCost_check" CHECK ("creditCost" >= 0);

ALTER TABLE "LlmModelMigration"
    ADD CONSTRAINT "LlmModelMigration_nodeCount_check" CHECK ("nodeCount" >= 0),
    ADD CONSTRAINT "LlmModelMigration_customCreditCost_check" CHECK ("customCreditCost" IS NULL OR "customCreditCost" >= 0);
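The two partial unique indexes on `LlmModelCost` are what let one default cost (`credentialId IS NULL`) coexist with credential-specific overrides for the same model and provider. SQLite supports the same partial-index syntax, so the mechanism can be demonstrated in miniature (table and index names here are illustrative, not the PR's):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE cost (model TEXT, provider TEXT, credential_id TEXT)")
# At most one default cost per (model, provider): uniqueness applies only
# to rows where credential_id IS NULL, mirroring LlmModelCost_default_cost_key.
con.execute(
    "CREATE UNIQUE INDEX default_cost ON cost(model, provider) "
    "WHERE credential_id IS NULL"
)

con.execute("INSERT INTO cost VALUES ('gpt-4o', 'openai', NULL)")
# A credential-specific row does not collide with the default row.
con.execute("INSERT INTO cost VALUES ('gpt-4o', 'openai', 'cred-1')")

# But a second default row violates the partial unique index.
try:
    con.execute("INSERT INTO cost VALUES ('gpt-4o', 'openai', NULL)")
    violated = False
except sqlite3.IntegrityError:
    violated = True
assert violated
```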
@@ -0,0 +1,287 @@
-- Seed LLM Registry from existing hard-coded data
-- This migration populates the LlmProvider, LlmModelCreator, LlmModel, and LlmModelCost tables
-- with data from the existing MODEL_METADATA and MODEL_COST dictionaries

-- Insert Providers
INSERT INTO "LlmProvider" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "defaultCredentialProvider", "defaultCredentialType", "metadata")
VALUES
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'OpenAI language models', 'openai', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Anthropic Claude models', 'anthropic', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'groq', 'Groq', 'Groq inference API', 'groq', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'open_router', 'OpenRouter', 'OpenRouter unified API', 'open_router', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'aiml_api', 'AI/ML API', 'AI/ML API models', 'aiml_api', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'ollama', 'Ollama', 'Ollama local models', 'ollama', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'llama_api', 'Llama API', 'Llama API models', 'llama_api', 'api_key', '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'v0', 'v0', 'v0 by Vercel models', 'v0', 'api_key', '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;

-- Insert Model Creators
INSERT INTO "LlmModelCreator" ("id", "createdAt", "updatedAt", "name", "displayName", "description", "websiteUrl", "logoUrl", "metadata")
VALUES
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'openai', 'OpenAI', 'Creator of GPT, O1, O3, and DALL-E models', 'https://openai.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'anthropic', 'Anthropic', 'Creator of Claude AI models', 'https://anthropic.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'meta', 'Meta', 'Creator of Llama foundation models', 'https://llama.meta.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'google', 'Google', 'Creator of Gemini and PaLM models', 'https://deepmind.google', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'mistralai', 'Mistral AI', 'Creator of Mistral and Codestral models', 'https://mistral.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'cohere', 'Cohere', 'Creator of Command language models', 'https://cohere.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'deepseek', 'DeepSeek', 'Creator of DeepSeek reasoning models', 'https://deepseek.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'alibaba', 'Alibaba', 'Creator of Qwen language models', 'https://qwenlm.github.io', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nvidia', 'NVIDIA', 'Creator of Nemotron models', 'https://nvidia.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'vercel', 'Vercel', 'Creator of v0 AI models', 'https://v0.dev', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'microsoft', 'Microsoft', 'Creator of Phi models', 'https://microsoft.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'xai', 'xAI', 'Creator of Grok models', 'https://x.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'perplexity', 'Perplexity AI', 'Creator of Sonar search models', 'https://perplexity.ai', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'nousresearch', 'Nous Research', 'Creator of Hermes language models', 'https://nousresearch.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'amazon', 'Amazon', 'Creator of Nova language models', 'https://aws.amazon.com', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'gryphe', 'Gryphe', 'Creator of MythoMax models', 'https://huggingface.co/Gryphe', NULL, '{}'::jsonb),
  (gen_random_uuid(), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 'moonshotai', 'Moonshot AI', 'Creator of Kimi language models', 'https://moonshot.ai', NULL, '{}'::jsonb)
ON CONFLICT ("name") DO NOTHING;

-- Insert Models (using CTEs to reference provider and creator IDs)
WITH provider_ids AS (
  SELECT "id", "name" FROM "LlmProvider"
),
creator_ids AS (
  SELECT "id", "name" FROM "LlmModelCreator"
)
INSERT INTO "LlmModel" ("id", "createdAt", "updatedAt", "slug", "displayName", "description", "providerId", "creatorId", "contextWindow", "maxOutputTokens", "isEnabled", "capabilities", "metadata")
SELECT
  gen_random_uuid(),
  CURRENT_TIMESTAMP,
  CURRENT_TIMESTAMP,
  model_slug,
  model_display_name,
  NULL,
  p."id",
  c."id",
  context_window,
  max_output_tokens,
  true,
  '{}'::jsonb,
  '{}'::jsonb
FROM (VALUES
  -- OpenAI models (creator: openai)
  ('o3-2025-04-16', 'O3', 'openai', 'openai', 200000, 100000),
  ('o3-mini', 'O3 Mini', 'openai', 'openai', 200000, 100000),
  ('o1', 'O1', 'openai', 'openai', 200000, 100000),
  ('o1-mini', 'O1 Mini', 'openai', 'openai', 128000, 65536),
  ('gpt-5.2-2025-12-11', 'GPT-5.2', 'openai', 'openai', 400000, 128000),
  ('gpt-5-2025-08-07', 'GPT 5', 'openai', 'openai', 400000, 128000),
  ('gpt-5.1-2025-11-13', 'GPT 5.1', 'openai', 'openai', 400000, 128000),
  ('gpt-5-mini-2025-08-07', 'GPT 5 Mini', 'openai', 'openai', 400000, 128000),
  ('gpt-5-nano-2025-08-07', 'GPT 5 Nano', 'openai', 'openai', 400000, 128000),
  ('gpt-5-chat-latest', 'GPT 5 Chat', 'openai', 'openai', 400000, 16384),
  ('gpt-4.1-2025-04-14', 'GPT 4.1', 'openai', 'openai', 1000000, 32768),
  ('gpt-4.1-mini-2025-04-14', 'GPT 4.1 Mini', 'openai', 'openai', 1047576, 32768),
  ('gpt-4o-mini', 'GPT 4o Mini', 'openai', 'openai', 128000, 16384),
  ('gpt-4o', 'GPT 4o', 'openai', 'openai', 128000, 16384),
  ('gpt-4-turbo', 'GPT 4 Turbo', 'openai', 'openai', 128000, 4096),
  -- Anthropic models (creator: anthropic)
  ('claude-opus-4-6', 'Claude Opus 4.6', 'anthropic', 'anthropic', 200000, 128000),
  ('claude-sonnet-4-6', 'Claude Sonnet 4.6', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-opus-4-1-20250805', 'Claude 4.1 Opus', 'anthropic', 'anthropic', 200000, 32000),
  ('claude-opus-4-20250514', 'Claude 4 Opus', 'anthropic', 'anthropic', 200000, 32000),
  ('claude-sonnet-4-20250514', 'Claude 4 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-opus-4-5-20251101', 'Claude 4.5 Opus', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-sonnet-4-5-20250929', 'Claude 4.5 Sonnet', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-haiku-4-5-20251001', 'Claude 4.5 Haiku', 'anthropic', 'anthropic', 200000, 64000),
  ('claude-3-haiku-20240307', 'Claude 3 Haiku', 'anthropic', 'anthropic', 200000, 4096),
  -- AI/ML API models (creators: alibaba, nvidia, meta)
  ('Qwen/Qwen2.5-72B-Instruct-Turbo', 'Qwen 2.5 72B', 'aiml_api', 'alibaba', 32000, 8000),
  ('nvidia/llama-3.1-nemotron-70b-instruct', 'Llama 3.1 Nemotron 70B', 'aiml_api', 'nvidia', 128000, 40000),
  ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 'Llama 3.3 70B', 'aiml_api', 'meta', 128000, NULL),
  ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'Meta Llama 3.1 70B', 'aiml_api', 'meta', 131000, 2000),
  ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 'Llama 3.2 3B', 'aiml_api', 'meta', 128000, NULL),
  -- Groq models (creator: meta for Llama)
  ('llama-3.3-70b-versatile', 'Llama 3.3 70B', 'groq', 'meta', 128000, 32768),
  ('llama-3.1-8b-instant', 'Llama 3.1 8B', 'groq', 'meta', 128000, 8192),
  -- Ollama models (creators: meta for Llama, mistralai for Mistral)
  ('llama3.3', 'Llama 3.3', 'ollama', 'meta', 8192, NULL),
  ('llama3.2', 'Llama 3.2', 'ollama', 'meta', 8192, NULL),
  ('llama3', 'Llama 3', 'ollama', 'meta', 8192, NULL),
  ('llama3.1:405b', 'Llama 3.1 405B', 'ollama', 'meta', 8192, NULL),
  ('dolphin-mistral:latest', 'Dolphin Mistral', 'ollama', 'mistralai', 32768, NULL),
  -- OpenRouter models (creators: google, mistralai, cohere, deepseek, perplexity, nousresearch, openai, amazon, microsoft, gryphe, meta, xai, moonshotai, alibaba)
  ('google/gemini-2.5-pro-preview-03-25', 'Gemini 2.5 Pro', 'open_router', 'google', 1050000, 8192),
  ('google/gemini-2.5-pro', 'Gemini 2.5 Pro', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-3.1-pro-preview', 'Gemini 3.1 Pro Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-3-flash-preview', 'Gemini 3 Flash Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-2.5-flash', 'Gemini 2.5 Flash', 'open_router', 'google', 1048576, 65535),
  ('google/gemini-2.0-flash-001', 'Gemini 2.0 Flash', 'open_router', 'google', 1048576, 8192),
  ('google/gemini-3.1-flash-lite-preview', 'Gemini 3.1 Flash Lite Preview', 'open_router', 'google', 1048576, 65536),
  ('google/gemini-2.5-flash-lite-preview-06-17', 'Gemini 2.5 Flash Lite Preview', 'open_router', 'google', 1048576, 65535),
  ('google/gemini-2.0-flash-lite-001', 'Gemini 2.0 Flash Lite', 'open_router', 'google', 1048576, 8192),
  ('mistralai/mistral-nemo', 'Mistral Nemo', 'open_router', 'mistralai', 128000, 4096),
  ('mistralai/mistral-large-2512', 'Mistral Large 3 2512', 'open_router', 'mistralai', 262144, NULL),
  ('mistralai/mistral-medium-3.1', 'Mistral Medium 3.1', 'open_router', 'mistralai', 131072, NULL),
  ('mistralai/mistral-small-3.2-24b-instruct', 'Mistral Small 3.2 24B', 'open_router', 'mistralai', 131072, 131072),
  ('mistralai/codestral-2508', 'Codestral 2508', 'open_router', 'mistralai', 256000, NULL),
  ('cohere/command-r-08-2024', 'Command R', 'open_router', 'cohere', 128000, 4096),
  ('cohere/command-r-plus-08-2024', 'Command R Plus', 'open_router', 'cohere', 128000, 4096),
  ('cohere/command-a-03-2025', 'Command A 03.2025', 'open_router', 'cohere', 256000, 8192),
  ('cohere/command-a-reasoning-08-2025', 'Command A Reasoning 08.2025', 'open_router', 'cohere', 256000, 32768),
  ('cohere/command-a-translate-08-2025', 'Command A Translate 08.2025', 'open_router', 'cohere', 128000, 8192),
  ('cohere/command-a-vision-07-2025', 'Command A Vision 07.2025', 'open_router', 'cohere', 128000, 8192),
  ('deepseek/deepseek-chat', 'DeepSeek Chat', 'open_router', 'deepseek', 64000, 2048),
  ('deepseek/deepseek-r1-0528', 'DeepSeek R1', 'open_router', 'deepseek', 163840, 163840),
  ('perplexity/sonar', 'Perplexity Sonar', 'open_router', 'perplexity', 127000, 8000),
  ('perplexity/sonar-pro', 'Perplexity Sonar Pro', 'open_router', 'perplexity', 200000, 8000),
  ('perplexity/sonar-deep-research', 'Perplexity Sonar Deep Research', 'open_router', 'perplexity', 128000, 16000),
  ('perplexity/sonar-reasoning-pro', 'Sonar Reasoning Pro', 'open_router', 'perplexity', 128000, 8000),
  ('nousresearch/hermes-3-llama-3.1-405b', 'Hermes 3 Llama 3.1 405B', 'open_router', 'nousresearch', 131000, 4096),
  ('nousresearch/hermes-3-llama-3.1-70b', 'Hermes 3 Llama 3.1 70B', 'open_router', 'nousresearch', 12288, 12288),
  ('openai/gpt-oss-120b', 'GPT OSS 120B', 'open_router', 'openai', 131072, 131072),
  ('openai/gpt-oss-20b', 'GPT OSS 20B', 'open_router', 'openai', 131072, 32768),
  ('amazon/nova-lite-v1', 'Amazon Nova Lite', 'open_router', 'amazon', 300000, 5120),
  ('amazon/nova-micro-v1', 'Amazon Nova Micro', 'open_router', 'amazon', 128000, 5120),
  ('amazon/nova-pro-v1', 'Amazon Nova Pro', 'open_router', 'amazon', 300000, 5120),
  ('microsoft/wizardlm-2-8x22b', 'WizardLM 2 8x22B', 'open_router', 'microsoft', 65536, 4096),
  ('microsoft/phi-4', 'Phi-4', 'open_router', 'microsoft', 16384, 16384),
  ('gryphe/mythomax-l2-13b', 'MythoMax L2 13B', 'open_router', 'gryphe', 4096, 4096),
  ('meta-llama/llama-4-scout', 'Llama 4 Scout', 'open_router', 'meta', 131072, 131072),
  ('meta-llama/llama-4-maverick', 'Llama 4 Maverick', 'open_router', 'meta', 1048576, 1000000),
  ('x-ai/grok-3', 'Grok 3', 'open_router', 'xai', 131072, 131072),
  ('x-ai/grok-4', 'Grok 4', 'open_router', 'xai', 256000, 256000),
  ('x-ai/grok-4-fast', 'Grok 4 Fast', 'open_router', 'xai', 2000000, 30000),
  ('x-ai/grok-4.1-fast', 'Grok 4.1 Fast', 'open_router', 'xai', 2000000, 30000),
  ('x-ai/grok-code-fast-1', 'Grok Code Fast 1', 'open_router', 'xai', 256000, 10000),
  ('moonshotai/kimi-k2', 'Kimi K2', 'open_router', 'moonshotai', 131000, 131000),
  ('qwen/qwen3-235b-a22b-thinking-2507', 'Qwen 3 235B Thinking', 'open_router', 'alibaba', 262144, 262144),
  ('qwen/qwen3-coder', 'Qwen 3 Coder', 'open_router', 'alibaba', 262144, 262144),
  -- Llama API models (creator: meta)
  ('Llama-4-Scout-17B-16E-Instruct-FP8', 'Llama 4 Scout', 'llama_api', 'meta', 128000, 4028),
  ('Llama-4-Maverick-17B-128E-Instruct-FP8', 'Llama 4 Maverick', 'llama_api', 'meta', 128000, 4028),
  ('Llama-3.3-8B-Instruct', 'Llama 3.3 8B', 'llama_api', 'meta', 128000, 4028),
  ('Llama-3.3-70B-Instruct', 'Llama 3.3 70B', 'llama_api', 'meta', 128000, 4028),
  -- v0 models (creator: vercel)
  ('v0-1.5-md', 'v0 1.5 MD', 'v0', 'vercel', 128000, 64000),
  ('v0-1.5-lg', 'v0 1.5 LG', 'v0', 'vercel', 512000, 64000),
  ('v0-1.0-md', 'v0 1.0 MD', 'v0', 'vercel', 128000, 64000)
) AS models(model_slug, model_display_name, provider_name, creator_name, context_window, max_output_tokens)
JOIN provider_ids p ON p."name" = models.provider_name
JOIN creator_ids c ON c."name" = models.creator_name
ON CONFLICT ("slug") DO NOTHING;

-- Insert Costs (using CTEs to reference model IDs)
WITH model_ids AS (
  SELECT "id", "slug", "providerId" FROM "LlmModel"
),
provider_ids AS (
  SELECT "id", "name" FROM "LlmProvider"
)
INSERT INTO "LlmModelCost" ("id", "createdAt", "updatedAt", "unit", "creditCost", "credentialProvider", "credentialId", "credentialType", "currency", "metadata", "llmModelId")
SELECT
  gen_random_uuid(),
  CURRENT_TIMESTAMP,
  CURRENT_TIMESTAMP,
  'RUN'::"LlmCostUnit",
  cost,
  p."name",
  NULL,
  'api_key',
  NULL,
  '{}'::jsonb,
  m."id"
FROM (VALUES
  -- OpenAI costs
  ('o3-2025-04-16', 4),
  ('o3-mini', 2),
  ('o1', 16),
  ('o1-mini', 4),
  ('gpt-5.2-2025-12-11', 5),
  ('gpt-5-2025-08-07', 2),
  ('gpt-5.1-2025-11-13', 5),
  ('gpt-5-mini-2025-08-07', 1),
  ('gpt-5-nano-2025-08-07', 1),
  ('gpt-5-chat-latest', 5),
  ('gpt-4.1-2025-04-14', 2),
  ('gpt-4.1-mini-2025-04-14', 1),
  ('gpt-4o-mini', 1),
  ('gpt-4o', 3),
  ('gpt-4-turbo', 10),
  -- Anthropic costs
  ('claude-opus-4-6', 21),
  ('claude-sonnet-4-6', 5),
  ('claude-opus-4-1-20250805', 21),
  ('claude-opus-4-20250514', 21),
  ('claude-sonnet-4-20250514', 5),
  ('claude-haiku-4-5-20251001', 4),
  ('claude-opus-4-5-20251101', 14),
  ('claude-sonnet-4-5-20250929', 9),
  ('claude-3-haiku-20240307', 1),
  -- AI/ML API costs
  ('Qwen/Qwen2.5-72B-Instruct-Turbo', 1),
  ('nvidia/llama-3.1-nemotron-70b-instruct', 1),
  ('meta-llama/Llama-3.3-70B-Instruct-Turbo', 1),
  ('meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 1),
  ('meta-llama/Llama-3.2-3B-Instruct-Turbo', 1),
  -- Groq costs
  ('llama-3.3-70b-versatile', 1),
  ('llama-3.1-8b-instant', 1),
  -- Ollama costs
  ('llama3.3', 1),
  ('llama3.2', 1),
  ('llama3', 1),
  ('llama3.1:405b', 1),
  ('dolphin-mistral:latest', 1),
  -- OpenRouter costs
  ('google/gemini-2.5-pro-preview-03-25', 4),
  ('google/gemini-2.5-pro', 4),
  ('google/gemini-3.1-pro-preview', 5),
  ('google/gemini-3-flash-preview', 3),
  ('google/gemini-3.1-flash-lite-preview', 1),
  ('mistralai/mistral-nemo', 1),
  ('mistralai/mistral-large-2512', 3),
  ('mistralai/mistral-medium-3.1', 2),
  ('mistralai/mistral-small-3.2-24b-instruct', 1),
  ('mistralai/codestral-2508', 2),
  ('cohere/command-r-08-2024', 1),
  ('cohere/command-r-plus-08-2024', 3),
  ('cohere/command-a-03-2025', 2),
  ('cohere/command-a-reasoning-08-2025', 3),
  ('cohere/command-a-translate-08-2025', 1),
  ('cohere/command-a-vision-07-2025', 2),
  ('deepseek/deepseek-chat', 2),
  ('perplexity/sonar', 1),
  ('perplexity/sonar-pro', 5),
  ('perplexity/sonar-deep-research', 10),
  ('perplexity/sonar-reasoning-pro', 5),
  ('nousresearch/hermes-3-llama-3.1-405b', 1),
  ('nousresearch/hermes-3-llama-3.1-70b', 1),
  ('amazon/nova-lite-v1', 1),
  ('amazon/nova-micro-v1', 1),
  ('amazon/nova-pro-v1', 1),
  ('microsoft/wizardlm-2-8x22b', 1),
  ('microsoft/phi-4', 1),
  ('gryphe/mythomax-l2-13b', 1),
  ('meta-llama/llama-4-scout', 1),
  ('meta-llama/llama-4-maverick', 1),
  ('x-ai/grok-3', 5),
  ('x-ai/grok-4', 9),
  ('x-ai/grok-4-fast', 1),
  ('x-ai/grok-4.1-fast', 1),
  ('x-ai/grok-code-fast-1', 1),
  ('moonshotai/kimi-k2', 1),
  ('qwen/qwen3-235b-a22b-thinking-2507', 1),
  ('qwen/qwen3-coder', 9),
  ('google/gemini-2.5-flash', 1),
  ('google/gemini-2.0-flash-001', 1),
  ('google/gemini-2.5-flash-lite-preview-06-17', 1),
  ('google/gemini-2.0-flash-lite-001', 1),
  ('deepseek/deepseek-r1-0528', 1),
  ('openai/gpt-oss-120b', 1),
  ('openai/gpt-oss-20b', 1),
  -- Llama API costs
  ('Llama-4-Scout-17B-16E-Instruct-FP8', 1),
  ('Llama-4-Maverick-17B-128E-Instruct-FP8', 1),
  ('Llama-3.3-8B-Instruct', 1),
  ('Llama-3.3-70B-Instruct', 1),
  -- v0 costs
  ('v0-1.5-md', 1),
  ('v0-1.5-lg', 2),
  ('v0-1.0-md', 1)
) AS costs(model_slug, cost)
JOIN model_ids m ON m."slug" = costs.model_slug
JOIN provider_ids p ON p."id" = m."providerId"
ON CONFLICT ("llmModelId", "credentialProvider", "unit") WHERE "credentialId" IS NULL DO NOTHING;
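Every statement in the seed above ends with `ON CONFLICT ... DO NOTHING`, which is what makes the migration idempotent: re-running it against an already-seeded database inserts nothing new. SQLite (3.24+, as bundled with recent Pythons) supports the same upsert clause, so the property is easy to demonstrate with an illustrative table:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE provider (name TEXT PRIMARY KEY, display TEXT)")

# Running the "seed" twice still leaves exactly one row per provider.
for _ in range(2):
    con.execute(
        "INSERT INTO provider VALUES ('openai', 'OpenAI') "
        "ON CONFLICT(name) DO NOTHING"
    )

count = con.execute("SELECT COUNT(*) FROM provider").fetchone()[0]
assert count == 1
```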
@@ -1301,3 +1301,164 @@ model OAuthRefreshToken {
|
||||
@@index([userId, applicationId])
|
||||
@@index([expiresAt]) // For cleanup
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLM Registry Models
|
||||
// ============================================================================
|
||||
|
||||
enum LlmCostUnit {
|
||||
RUN
|
||||
TOKENS
|
||||
}
|
||||
|
||||
model LlmProvider {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
defaultCredentialProvider String?
|
||||
defaultCredentialId String?
|
||||
defaultCredentialType String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModel {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
slug String @unique
|
||||
displayName String
|
||||
description String?
|
||||
|
||||
providerId String
|
||||
Provider LlmProvider @relation(fields: [providerId], references: [id], onDelete: Restrict)
|
||||
|
||||
// Creator is the organization that created/trained the model (e.g., OpenAI, Meta)
|
||||
// This is distinct from the provider who hosts/serves the model (e.g., OpenRouter)
|
||||
creatorId String?
|
||||
Creator LlmModelCreator? @relation(fields: [creatorId], references: [id], onDelete: SetNull)
|
||||
|
||||
contextWindow Int
|
||||
maxOutputTokens Int?
|
||||
priceTier Int @default(1) // 1=cheapest, 2=medium, 3=expensive (DB constraint: 1-3)
|
||||
isEnabled Boolean @default(true)
|
||||
isRecommended Boolean @default(false)
|
||||
|
||||
// Model-specific capabilities
|
||||
// These vary per model even within the same provider (e.g., Hugging Face)
|
||||
// Default to false for safety - partially-seeded rows should not be assumed capable
|
||||
supportsTools Boolean @default(false)
|
||||
supportsJsonOutput Boolean @default(false)
|
||||
supportsReasoning Boolean @default(false)
|
||||
supportsParallelToolCalls Boolean @default(false)
|
||||
|
||||
capabilities Json @default("{}")
|
||||
metadata Json @default("{}")
|
||||
|
||||
Costs LlmModelCost[]
|
||||
SourceMigrations LlmModelMigration[] @relation("SourceMigrations")
|
||||
TargetMigrations LlmModelMigration[] @relation("TargetMigrations")
|
||||
|
||||
@@index([providerId, isEnabled])
|
||||
@@index([creatorId])
|
||||
// Note: slug already has @unique which creates an implicit index
|
||||
}
|
||||
|
||||
model LlmModelCost {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
unit LlmCostUnit @default(RUN)
|
||||
|
||||
creditCost Int // DB constraint: >= 0
|
||||
|
||||
// Provider identifier (e.g., "openai", "anthropic", "openrouter")
|
||||
// Used to determine which credential system provides the API key.
|
||||
// Allows different pricing for:
|
||||
// - Default provider costs (WHERE credentialId IS NULL)
|
||||
// - User's own API key costs (WHERE credentialId IS NOT NULL)
|
||||
credentialProvider String
|
||||
credentialId String?
|
||||
credentialType String?
|
||||
currency String?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
llmModelId String
|
||||
Model LlmModel @relation(fields: [llmModelId], references: [id], onDelete: Cascade)
|
||||
|
||||
// Note: Unique constraints are implemented as partial indexes in migration SQL:
|
||||
// - One for default costs (WHERE credentialId IS NULL)
|
||||
// - One for credential-specific costs (WHERE credentialId IS NOT NULL)
|
||||
// This allows both provider-level defaults and credential-specific overrides
|
||||
}
|
||||
|
||||
model LlmModelCreator {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
name String @unique // e.g., "openai", "anthropic", "meta"
|
||||
displayName String // e.g., "OpenAI", "Anthropic", "Meta"
|
||||
description String?
|
||||
websiteUrl String? // Link to creator's website
|
||||
logoUrl String? // URL to creator's logo
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
Models LlmModel[]
|
||||
|
||||
}
|
||||
|
||||
model LlmModelMigration {
|
||||
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  sourceModelSlug String  // The original model that was disabled
  targetModelSlug String  // The model workflows were migrated to
  reason          String? // Why the migration happened (e.g., "Provider outage")

  // FK constraints ensure slugs reference valid models
  SourceModel LlmModel @relation("SourceMigrations", fields: [sourceModelSlug], references: [slug], onDelete: Restrict)
  TargetModel LlmModel @relation("TargetMigrations", fields: [targetModelSlug], references: [slug], onDelete: Restrict)

  // Track affected nodes as JSON array of node IDs
  // Format: ["node-uuid-1", "node-uuid-2", ...]
  migratedNodeIds Json @default("[]")
  nodeCount       Int  // Number of nodes migrated (DB constraint: >= 0)

  // Custom pricing override for migrated workflows during the migration period.
  // Use case: When migrating users from an expensive model (e.g., GPT-4) to a cheaper
  // one (e.g., GPT-3.5), you may want to temporarily maintain the original pricing
  // to avoid billing surprises, or offer a discount during the transition.
  //
  // IMPORTANT: This field is intended for integration with the billing system.
  // When billing calculates costs for nodes affected by this migration, it should
  // check if customCreditCost is set and use it instead of the target model's cost.
  // If null, the target model's normal cost applies.
  //
  // TODO: Integrate with billing system to apply this override during cost calculation.
  // LIMITATION: This is a simple Int and doesn't distinguish RUN vs TOKENS pricing.
  // For token-priced models, this may be ambiguous. Consider migrating to a relation
  // with LlmModelCost or a dedicated override model in a follow-up PR.
  customCreditCost Int? // DB constraint: >= 0 when not null

  // Revert tracking
  isReverted Boolean   @default(false)
  revertedAt DateTime?

  // Note: Partial unique index in migration SQL prevents multiple active migrations per source:
  //   UNIQUE (sourceModelSlug) WHERE isReverted = false
  @@index([targetModelSlug])
  @@index([sourceModelSlug, isReverted]) // Composite index for active migration queries
}
autogpt_platform/backend/scripts/refresh_claude_token.sh (new executable file, 123 lines)
@@ -0,0 +1,123 @@
#!/usr/bin/env bash
# refresh_claude_token.sh — Extract Claude OAuth tokens and update backend/.env
#
# Works on macOS (keychain), Linux (~/.claude/.credentials.json),
# and Windows/WSL (~/.claude/.credentials.json or PowerShell fallback).
#
# Usage:
#   ./scripts/refresh_claude_token.sh                            # auto-detect OS
#   ./scripts/refresh_claude_token.sh --env-file /path/to/.env   # custom .env path
#
# Prerequisite: You must have run `claude login` at least once on the host.

set -euo pipefail

# --- Parse arguments ---
ENV_FILE=""
while [[ $# -gt 0 ]]; do
  case "$1" in
    --env-file) ENV_FILE="$2"; shift 2 ;;
    *) echo "Unknown option: $1"; exit 1 ;;
  esac
done

# Default .env path: relative to this script's location
if [[ -z "$ENV_FILE" ]]; then
  SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
  ENV_FILE="$SCRIPT_DIR/../.env"
fi

# --- Extract tokens by platform ---
ACCESS_TOKEN=""
REFRESH_TOKEN=""

extract_from_credentials_file() {
  local creds_file="$1"
  if [[ -f "$creds_file" ]]; then
    ACCESS_TOKEN=$(jq -r '.claudeAiOauth.accessToken // ""' "$creds_file" 2>/dev/null)
    REFRESH_TOKEN=$(jq -r '.claudeAiOauth.refreshToken // ""' "$creds_file" 2>/dev/null)
  fi
}

case "$(uname -s)" in
  Darwin)
    # macOS: extract from system keychain
    CREDS_JSON=$(security find-generic-password -s "Claude Code-credentials" -w 2>/dev/null || true)
    if [[ -n "$CREDS_JSON" ]]; then
      ACCESS_TOKEN=$(echo "$CREDS_JSON" | jq -r '.claudeAiOauth.accessToken // ""' 2>/dev/null)
      REFRESH_TOKEN=$(echo "$CREDS_JSON" | jq -r '.claudeAiOauth.refreshToken // ""' 2>/dev/null)
    else
      # Fallback to credentials file (e.g. if keychain access denied)
      extract_from_credentials_file "$HOME/.claude/.credentials.json"
    fi
    ;;
  Linux)
    # Linux (including WSL): read from credentials file
    extract_from_credentials_file "$HOME/.claude/.credentials.json"
    ;;
  MINGW*|MSYS*|CYGWIN*)
    # Windows Git Bash / MSYS2 / Cygwin
    APPDATA_PATH="${APPDATA:-$USERPROFILE/AppData/Roaming}"
    extract_from_credentials_file "$APPDATA_PATH/claude/.credentials.json"
    # Fallback to home dir
    if [[ -z "$ACCESS_TOKEN" ]]; then
      extract_from_credentials_file "$HOME/.claude/.credentials.json"
    fi
    ;;
  *)
    echo "Unsupported platform: $(uname -s)"
    exit 1
    ;;
esac

# --- Validate ---
if [[ -z "$ACCESS_TOKEN" ]]; then
  echo "ERROR: Could not extract Claude OAuth token."
  echo ""
  echo "Make sure you have run 'claude login' at least once."
  echo ""
  echo "Locations checked:"
  echo "  macOS:   Keychain ('Claude Code-credentials')"
  echo "  Linux:   ~/.claude/.credentials.json"
  echo "  Windows: %APPDATA%/claude/.credentials.json"
  exit 1
fi

echo "Found Claude OAuth token: ${ACCESS_TOKEN:0:20}..."
[[ -n "$REFRESH_TOKEN" ]] && echo "Found refresh token: ${REFRESH_TOKEN:0:20}..."

# --- Update .env file ---
update_env_var() {
  local key="$1" value="$2" file="$3"
  if grep -q "^${key}=" "$file" 2>/dev/null; then
    # Replace existing value (works on both macOS and Linux sed)
    if [[ "$(uname -s)" == "Darwin" ]]; then
      sed -i '' "s|^${key}=.*|${key}=${value}|" "$file"
    else
      sed -i "s|^${key}=.*|${key}=${value}|" "$file"
    fi
  elif grep -q "^# *${key}=" "$file" 2>/dev/null; then
    # Uncomment and set
    if [[ "$(uname -s)" == "Darwin" ]]; then
      sed -i '' "s|^# *${key}=.*|${key}=${value}|" "$file"
    else
      sed -i "s|^# *${key}=.*|${key}=${value}|" "$file"
    fi
  else
    # Append
    echo "${key}=${value}" >> "$file"
  fi
}

if [[ ! -f "$ENV_FILE" ]]; then
  echo "WARNING: $ENV_FILE does not exist, creating it."
  touch "$ENV_FILE"
fi

update_env_var "CLAUDE_CODE_OAUTH_TOKEN" "$ACCESS_TOKEN" "$ENV_FILE"
[[ -n "$REFRESH_TOKEN" ]] && update_env_var "CLAUDE_CODE_REFRESH_TOKEN" "$REFRESH_TOKEN" "$ENV_FILE"
update_env_var "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION" "true" "$ENV_FILE"

echo ""
echo "Updated $ENV_FILE with Claude subscription tokens."
echo "Run 'docker compose up -d copilot_executor' to apply."
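The three-branch update logic of `update_env_var` above (replace an existing `KEY=` line, uncomment a commented `# KEY=` line, otherwise append) can be sketched outside the script. The following TypeScript mirror is for illustration only and is not part of the PR:

```typescript
// Mirror of the script's update_env_var branches, for illustration:
// 1) replace an existing `KEY=` line, 2) uncomment a `# KEY=` line, 3) append.
function updateEnvVar(key: string, value: string, envText: string): string {
  const esc = key.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); // escape regex metachars
  const existing = new RegExp(`^${esc}=.*$`, "m");
  const commented = new RegExp(`^# *${esc}=.*$`, "m");
  if (existing.test(envText)) return envText.replace(existing, `${key}=${value}`);
  if (commented.test(envText)) return envText.replace(commented, `${key}=${value}`);
  const sep = envText === "" || envText.endsWith("\n") ? "" : "\n";
  return `${envText}${sep}${key}=${value}\n`;
}
```

Unlike the sed-based original, this operates on the file contents as a string; the branch order (existing assignment wins over a commented one) matches the script.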
@@ -73,7 +73,7 @@
     "@vercel/analytics": "1.5.0",
     "@vercel/speed-insights": "1.2.0",
     "@xyflow/react": "12.9.2",
-    "ai": "6.0.59",
+    "ai": "6.0.134",
     "boring-avatars": "1.11.2",
     "canvas-confetti": "1.9.4",
     "class-variance-authority": "0.7.1",
autogpt_platform/frontend/pnpm-lock.yaml (generated, 68 lines changed)
@@ -142,8 +142,8 @@ importers:
         specifier: 12.9.2
         version: 12.9.2(@types/react@18.3.17)(immer@11.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       ai:
-        specifier: 6.0.59
-        version: 6.0.59(zod@3.25.76)
+        specifier: 6.0.134
+        version: 6.0.134(zod@3.25.76)
       boring-avatars:
         specifier: 1.11.2
         version: 1.11.2
@@ -448,16 +448,32 @@ packages:
     peerDependencies:
       zod: ^3.25.76 || ^4.1.8

+  '@ai-sdk/gateway@3.0.77':
+    resolution: {integrity: sha512-UdwIG2H2YMuntJQ5L+EmED5XiwnlvDT3HOmKfVFxR4Nq/RSLFA/HcchhwfNXHZ5UJjyuL2VO0huLbWSZ9ijemQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: ^3.25.76 || ^4.1.8
+
   '@ai-sdk/provider-utils@4.0.10':
     resolution: {integrity: sha512-VeDAiCH+ZK8Xs4hb9Cw7pHlujWNL52RKe8TExOkrw6Ir1AmfajBZTb9XUdKOZO08RwQElIKA8+Ltm+Gqfo8djQ==}
     engines: {node: '>=18'}
     peerDependencies:
       zod: ^3.25.76 || ^4.1.8

+  '@ai-sdk/provider-utils@4.0.21':
+    resolution: {integrity: sha512-MtFUYI1/8mgDvRmaBDjbLJPFFrMG777AvSgyIFQtZHIMzm88R/12vYBBpnk7pfiWLFE1DSZzY4WDYzGbKAcmiw==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: ^3.25.76 || ^4.1.8
+
   '@ai-sdk/provider@3.0.5':
     resolution: {integrity: sha512-2Xmoq6DBJqmSl80U6V9z5jJSJP7ehaJJQMy2iFUqTay06wdCqTnPVBBQbtEL8RCChenL+q5DC5H5WzU3vV3v8w==}
     engines: {node: '>=18'}

+  '@ai-sdk/provider@3.0.8':
+    resolution: {integrity: sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==}
+    engines: {node: '>=18'}
+
   '@ai-sdk/react@3.0.61':
     resolution: {integrity: sha512-vCjZBnY2+TawFBXamSKt6elAt9n1MXMfcjSd9DSgT9peCJN27qNGVSXgaGNh/B3cUgeOktFfhB2GVmIqOjvmLQ==}
     engines: {node: '>=18'}
@@ -4053,6 +4069,12 @@ packages:
     resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
     engines: {node: '>= 14'}

+  ai@6.0.134:
+    resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: ^3.25.76 || ^4.1.8
+
   ai@6.0.59:
     resolution: {integrity: sha512-9SfCvcr4kVk4t8ZzIuyHpuL1hFYKsYMQfBSbBq3dipXPa+MphARvI8wHEjNaRqYl3JOsJbWxEBIMqHL0L92mUA==}
     engines: {node: '>=18'}
@@ -8718,6 +8740,13 @@ snapshots:
       '@vercel/oidc': 3.1.0
       zod: 3.25.76

+  '@ai-sdk/gateway@3.0.77(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.8
+      '@ai-sdk/provider-utils': 4.0.21(zod@3.25.76)
+      '@vercel/oidc': 3.1.0
+      zod: 3.25.76
+
   '@ai-sdk/provider-utils@4.0.10(zod@3.25.76)':
     dependencies:
       '@ai-sdk/provider': 3.0.5
@@ -8725,10 +8754,21 @@ snapshots:
       eventsource-parser: 3.0.6
       zod: 3.25.76

+  '@ai-sdk/provider-utils@4.0.21(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.8
+      '@standard-schema/spec': 1.1.0
+      eventsource-parser: 3.0.6
+      zod: 3.25.76
+
   '@ai-sdk/provider@3.0.5':
     dependencies:
       json-schema: 0.4.0

+  '@ai-sdk/provider@3.0.8':
+    dependencies:
+      json-schema: 0.4.0
+
   '@ai-sdk/react@3.0.61(react@18.3.1)(zod@3.25.76)':
     dependencies:
       '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76)
@@ -12798,6 +12838,14 @@ snapshots:
   agent-base@7.1.4:
     optional: true

+  ai@6.0.134(zod@3.25.76):
+    dependencies:
+      '@ai-sdk/gateway': 3.0.77(zod@3.25.76)
+      '@ai-sdk/provider': 3.0.8
+      '@ai-sdk/provider-utils': 4.0.21(zod@3.25.76)
+      '@opentelemetry/api': 1.9.0
+      zod: 3.25.76
+
   ai@6.0.59(zod@3.25.76):
     dependencies:
       '@ai-sdk/gateway': 3.0.27(zod@3.25.76)
@@ -14066,8 +14114,8 @@ snapshots:
       '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
       eslint: 8.57.1
       eslint-import-resolver-node: 0.3.9
-      eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
-      eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
+      eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
+      eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
       eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
       eslint-plugin-react: 7.37.5(eslint@8.57.1)
       eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
@@ -14086,7 +14134,7 @@ snapshots:
     transitivePeerDependencies:
       - supports-color

-  eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
+  eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1):
     dependencies:
       '@nolyfill/is-core-module': 1.0.39
       debug: 4.4.3
@@ -14097,22 +14145,22 @@ snapshots:
       tinyglobby: 0.2.15
       unrs-resolver: 1.11.1
     optionalDependencies:
-      eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
+      eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
     transitivePeerDependencies:
       - supports-color

-  eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
+  eslint-module-utils@2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
     dependencies:
       debug: 3.2.7
     optionalDependencies:
       '@typescript-eslint/parser': 8.52.0(eslint@8.57.1)(typescript@5.9.3)
       eslint: 8.57.1
       eslint-import-resolver-node: 0.3.9
-      eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
+      eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1)
     transitivePeerDependencies:
       - supports-color

-  eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
+  eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
     dependencies:
       '@rtsao/scc': 1.1.0
       array-includes: 3.1.9
@@ -14123,7 +14171,7 @@ snapshots:
       doctrine: 2.1.0
       eslint: 8.57.1
       eslint-import-resolver-node: 0.3.9
-      eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
+      eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.52.0(eslint@8.57.1)(typescript@5.9.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
       hasown: 2.0.2
       is-core-module: 2.16.1
       is-glob: 4.0.3
@@ -15,46 +15,11 @@ import { useCopilotUIStore } from "./store";
 import { useChatSession } from "./useChatSession";
 import { useCopilotNotifications } from "./useCopilotNotifications";
 import { useCopilotStream } from "./useCopilotStream";
+import { useWorkflowImportAutoSubmit } from "./useWorkflowImportAutoSubmit";

 const TITLE_POLL_INTERVAL_MS = 2_000;
 const TITLE_POLL_MAX_ATTEMPTS = 5;

-/**
- * Extract a prompt from the URL hash fragment.
- * Supports: /copilot#prompt=URL-encoded-text
- * Optionally auto-submits if ?autosubmit=true is in the query string.
- * Returns null if no prompt is present.
- */
-function extractPromptFromUrl(): {
-  prompt: string;
-  autosubmit: boolean;
-} | null {
-  if (typeof window === "undefined") return null;
-
-  const hash = window.location.hash;
-  if (!hash) return null;
-
-  const hashParams = new URLSearchParams(hash.slice(1));
-  const prompt = hashParams.get("prompt");
-
-  if (!prompt || !prompt.trim()) return null;
-
-  const searchParams = new URLSearchParams(window.location.search);
-  const autosubmit = searchParams.get("autosubmit") === "true";
-
-  // Clean up hash + autosubmit param only (preserve other query params)
-  const cleanURL = new URL(window.location.href);
-  cleanURL.hash = "";
-  cleanURL.searchParams.delete("autosubmit");
-  window.history.replaceState(
-    null,
-    "",
-    `${cleanURL.pathname}${cleanURL.search}`,
-  );
-
-  return { prompt: prompt.trim(), autosubmit };
-}
-
 interface UploadedFile {
   file_id: string;
   name: string;
@@ -130,16 +95,23 @@ export function useCopilotPage() {
     breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";

   const pendingFilesRef = useRef<File[]>([]);
+  // Pre-built file parts from workflow import (already uploaded, skip re-upload)
+  const pendingFilePartsRef = useRef<FileUIPart[]>([]);

   // --- Send pending message after session creation ---
   useEffect(() => {
     if (!sessionId || pendingMessage === null) return;
     const msg = pendingMessage;
     const files = pendingFilesRef.current;
+    const prebuiltParts = pendingFilePartsRef.current;
     setPendingMessage(null);
     pendingFilesRef.current = [];
+    pendingFilePartsRef.current = [];

-    if (files.length > 0) {
+    if (prebuiltParts.length > 0) {
+      // File already uploaded (e.g. workflow import) — send directly
+      sendMessage({ text: msg, files: prebuiltParts });
+    } else if (files.length > 0) {
       setIsUploadingFiles(true);
       void uploadFiles(files, sessionId)
         .then((uploaded) => {
@@ -164,26 +136,11 @@ export function useCopilotPage() {
   }, [sessionId, pendingMessage, sendMessage]);

-  // --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
-  const { setInitialPrompt } = useCopilotUIStore();
-  const hasProcessedUrlPrompt = useRef(false);
-  useEffect(() => {
-    if (hasProcessedUrlPrompt.current) return;
-
-    const urlPrompt = extractPromptFromUrl();
-    if (!urlPrompt) return;
-
-    hasProcessedUrlPrompt.current = true;
-
-    if (urlPrompt.autosubmit) {
-      setPendingMessage(urlPrompt.prompt);
-      void createSession().catch(() => {
-        setPendingMessage(null);
-        setInitialPrompt(urlPrompt.prompt);
-      });
-    } else {
-      setInitialPrompt(urlPrompt.prompt);
-    }
-  }, [createSession, setInitialPrompt]);
+  useWorkflowImportAutoSubmit({
+    createSession,
+    setPendingMessage,
+    pendingFilePartsRef,
+  });

   async function uploadFiles(
     files: File[],
@@ -0,0 +1,122 @@
import type { FileUIPart } from "ai";
import { useEffect, useRef } from "react";
import { useCopilotUIStore } from "./store";

/**
 * Extract a prompt from the URL hash fragment.
 * Supports: /copilot#prompt=URL-encoded-text
 * Optionally auto-submits if ?autosubmit=true is in the query string.
 * Returns null if no prompt is present.
 */
function extractPromptFromUrl(): {
  prompt: string;
  autosubmit: boolean;
  filePart?: FileUIPart;
} | null {
  if (typeof window === "undefined") return null;

  const searchParams = new URLSearchParams(window.location.search);
  const autosubmit = searchParams.get("autosubmit") === "true";

  // Check sessionStorage first (used by workflow import for large prompts)
  const storedPrompt = sessionStorage.getItem("importWorkflowPrompt");
  if (storedPrompt) {
    sessionStorage.removeItem("importWorkflowPrompt");

    // Check for a pre-uploaded workflow file attached to this import
    let filePart: FileUIPart | undefined;
    const storedFile = sessionStorage.getItem("importWorkflowFile");
    if (storedFile) {
      sessionStorage.removeItem("importWorkflowFile");
      try {
        const { fileId, fileName, mimeType } = JSON.parse(storedFile);
        // Validate fileId is a UUID to prevent path traversal
        const UUID_RE =
          /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
        if (typeof fileId === "string" && UUID_RE.test(fileId)) {
          filePart = {
            type: "file",
            mediaType: mimeType ?? "application/json",
            filename: fileName ?? "workflow.json",
            url: `/api/proxy/api/workspace/files/${fileId}/download`,
          };
        }
      } catch {
        // ignore malformed stored data
      }
    }

    // Clean up query params
    const cleanURL = new URL(window.location.href);
    cleanURL.searchParams.delete("autosubmit");
    cleanURL.searchParams.delete("source");
    window.history.replaceState(
      null,
      "",
      `${cleanURL.pathname}${cleanURL.search}`,
    );
    return { prompt: storedPrompt.trim(), autosubmit, filePart };
  }

  // Fall back to URL hash (e.g. /copilot#prompt=...)
  const hash = window.location.hash;
  if (!hash) return null;

  const hashParams = new URLSearchParams(hash.slice(1));
  const prompt = hashParams.get("prompt");

  if (!prompt || !prompt.trim()) return null;

  // Clean up hash + autosubmit param only (preserve other query params)
  const cleanURL = new URL(window.location.href);
  cleanURL.hash = "";
  cleanURL.searchParams.delete("autosubmit");
  window.history.replaceState(
    null,
    "",
    `${cleanURL.pathname}${cleanURL.search}`,
  );

  return { prompt: prompt.trim(), autosubmit };
}

/**
 * Hook that checks for workflow import data in sessionStorage / URL on mount,
 * and auto-submits a new CoPilot session when `autosubmit=true`.
 *
 * Extracted from useCopilotPage to keep that hook focused on page-level concerns.
 */
export function useWorkflowImportAutoSubmit({
  createSession,
  setPendingMessage,
  pendingFilePartsRef,
}: {
  createSession: () => Promise<string | undefined>;
  setPendingMessage: (msg: string | null) => void;
  pendingFilePartsRef: React.MutableRefObject<FileUIPart[]>;
}) {
  const { setInitialPrompt } = useCopilotUIStore();
  const hasProcessedUrlPrompt = useRef(false);

  useEffect(() => {
    if (hasProcessedUrlPrompt.current) return;

    const urlPrompt = extractPromptFromUrl();
    if (!urlPrompt) return;

    hasProcessedUrlPrompt.current = true;

    if (urlPrompt.autosubmit) {
      if (urlPrompt.filePart) {
        pendingFilePartsRef.current = [urlPrompt.filePart];
      }
      setPendingMessage(urlPrompt.prompt);
      void createSession().catch(() => {
        setPendingMessage(null);
        setInitialPrompt(urlPrompt.prompt);
      });
    } else {
      setInitialPrompt(urlPrompt.prompt);
    }
  }, [createSession, setInitialPrompt, setPendingMessage, pendingFilePartsRef]);
}
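The UUID check in the hook above is what keeps a crafted sessionStorage value from steering the download URL (`/api/proxy/api/workspace/files/${fileId}/download`). Its behavior can be shown in isolation; the helper name below is ours, not from the PR:

```typescript
// Same UUID pattern as the hook's guard: 8-4-4-4-12 hex groups, case-insensitive.
const UUID_RE =
  /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;

// Illustrative type guard: only a well-formed UUID string may become a fileId.
function isSafeFileId(fileId: unknown): fileId is string {
  return typeof fileId === "string" && UUID_RE.test(fileId);
}
```

Because the regex is anchored at both ends, path segments like `../` can never appear in an accepted value.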
@@ -169,7 +169,7 @@ function renderMarkdown(
         [remarkMath, { singleDollarTextMath: false }], // Math support for LaTeX
       ]}
       rehypePlugins={[
-        rehypeKatex, // Render math with KaTeX
+        [rehypeKatex, { strict: false }], // Render math with KaTeX
         rehypeHighlight, // Syntax highlighting for code blocks
         rehypeSlug, // Add IDs to headings
         [rehypeAutolinkHeadings, { behavior: "wrap" }], // Make headings clickable
@@ -1,5 +1,5 @@
+import LibraryImportDialog from "../LibraryImportDialog/LibraryImportDialog";
 import { LibrarySearchBar } from "../LibrarySearchBar/LibrarySearchBar";
-import LibraryUploadAgentDialog from "../LibraryUploadAgentDialog/LibraryUploadAgentDialog";

 interface Props {
   setSearchTerm: (value: string) => void;
@@ -10,13 +10,13 @@ export function LibraryActionHeader({ setSearchTerm }: Props) {
     <>
       <div className="mb-[32px] hidden items-center justify-center gap-4 md:flex">
         <LibrarySearchBar setSearchTerm={setSearchTerm} />
-        <LibraryUploadAgentDialog />
+        <LibraryImportDialog />
       </div>

       {/* Mobile and tablet */}
       <div className="flex flex-col gap-4 p-4 pt-[52px] md:hidden">
-        <div className="flex w-full justify-between">
-          <LibraryUploadAgentDialog />
+        <div className="flex w-full justify-between gap-2">
+          <LibraryImportDialog />
         </div>

         <div className="flex items-center justify-center">
@@ -0,0 +1,66 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import {
  TabsLine,
  TabsLineList,
  TabsLineTrigger,
} from "@/components/molecules/TabsLine/TabsLine";
import { UploadSimpleIcon } from "@phosphor-icons/react";
import { useState } from "react";
import { useLibraryUploadAgentDialog } from "../LibraryUploadAgentDialog/useLibraryUploadAgentDialog";
import AgentUploadTab from "./components/AgentUploadTab/AgentUploadTab";
import ExternalWorkflowTab from "./components/ExternalWorkflowTab/ExternalWorkflowTab";
import { useExternalWorkflowTab } from "./components/ExternalWorkflowTab/useExternalWorkflowTab";

export default function LibraryImportDialog() {
  const [isOpen, setIsOpen] = useState(false);

  const importWorkflow = useExternalWorkflowTab();

  function handleClose() {
    setIsOpen(false);
    importWorkflow.setFileValue("");
    importWorkflow.setUrlValue("");
  }

  const upload = useLibraryUploadAgentDialog({ onSuccess: handleClose });

  return (
    <Dialog
      title="Import"
      styling={{ maxWidth: "32rem" }}
      controlled={{
        isOpen,
        set: setIsOpen,
      }}
      onClose={handleClose}
    >
      <Dialog.Trigger>
        <Button
          data-testid="import-button"
          variant="primary"
          className="h-[2.78rem] w-full md:w-[10rem]"
          size="small"
        >
          <UploadSimpleIcon width={18} height={18} />
          <span>Import</span>
        </Button>
      </Dialog.Trigger>
      <Dialog.Content>
        <TabsLine defaultValue="agent">
          <TabsLineList>
            <TabsLineTrigger value="agent">AutoGPT agent</TabsLineTrigger>
            <TabsLineTrigger value="platform">Another platform</TabsLineTrigger>
          </TabsLineList>

          {/* Tab: Import from any platform (file upload + n8n URL) */}
          <ExternalWorkflowTab importWorkflow={importWorkflow} />

          {/* Tab: Upload AutoGPT agent JSON */}
          <AgentUploadTab upload={upload} />
        </TabsLine>
      </Dialog.Content>
    </Dialog>
  );
}
@@ -0,0 +1,105 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { FileInput } from "@/components/atoms/FileInput/FileInput";
import { Input } from "@/components/atoms/Input/Input";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import {
  Form,
  FormControl,
  FormField,
  FormItem,
  FormMessage,
} from "@/components/molecules/Form/Form";
import { TabsLineContent } from "@/components/molecules/TabsLine/TabsLine";
import { useLibraryUploadAgentDialog } from "../../../LibraryUploadAgentDialog/useLibraryUploadAgentDialog";

type AgentUploadTabProps = {
  upload: ReturnType<typeof useLibraryUploadAgentDialog>;
};

export default function AgentUploadTab({ upload }: AgentUploadTabProps) {
  return (
    <TabsLineContent value="agent">
      <p className="mb-4 text-sm text-neutral-500">
        Upload a previously exported AutoGPT agent file (.json).
      </p>
      <Form
        form={upload.form}
        onSubmit={upload.onSubmit}
        className="flex flex-col justify-center gap-0 px-1"
      >
        <FormField
          control={upload.form.control}
          name="agentName"
          render={({ field }) => (
            <FormItem>
              <FormControl>
                <Input
                  {...field}
                  id={field.name}
                  label="Agent name"
                  className="w-full rounded-[10px]"
                />
              </FormControl>
              <FormMessage />
            </FormItem>
          )}
        />
        <FormField
          control={upload.form.control}
          name="agentDescription"
          render={({ field }) => (
            <FormItem>
              <FormControl>
                <Input
                  {...field}
                  id={field.name}
                  label="Agent description"
                  type="textarea"
                  className="w-full rounded-[10px]"
                />
              </FormControl>
              <FormMessage />
            </FormItem>
          )}
        />
        <FormField
          control={upload.form.control}
          name="agentFile"
          render={({ field }) => (
            <FormItem>
              <FormControl>
                <FileInput
                  mode="base64"
                  value={field.value}
                  onChange={field.onChange}
                  accept=".json,application/json"
                  placeholder="Agent file"
                  maxFileSize={10 * 1024 * 1024}
                  showStorageNote={false}
                  className="mb-8 mt-4"
                />
              </FormControl>
              <FormMessage />
            </FormItem>
          )}
        />
        <Button
          type="submit"
          variant="primary"
          className="w-full"
          disabled={!upload.agentObject || upload.isUploading}
        >
          {upload.isUploading ? (
            <div className="flex items-center gap-2">
              <LoadingSpinner size="small" className="text-white" />
              <span>Uploading...</span>
            </div>
          ) : (
            "Upload"
          )}
        </Button>
      </Form>
    </TabsLineContent>
  );
}
@@ -0,0 +1,99 @@
"use client";
import { Button } from "@/components/atoms/Button/Button";
import { FileInput } from "@/components/atoms/FileInput/FileInput";
import { Input } from "@/components/atoms/Input/Input";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { TabsLineContent } from "@/components/molecules/TabsLine/TabsLine";
import { useExternalWorkflowTab } from "./useExternalWorkflowTab";

const N8N_EXAMPLES = [
  { label: "Build Your First AI Agent", url: "https://n8n.io/workflows/6270" },
  { label: "Interactive AI Chat Agent", url: "https://n8n.io/workflows/5819" },
];

type ExternalWorkflowTabProps = {
  importWorkflow: ReturnType<typeof useExternalWorkflowTab>;
};

export default function ExternalWorkflowTab({
  importWorkflow,
}: ExternalWorkflowTabProps) {
  return (
    <TabsLineContent value="platform">
      <p className="mb-4 text-sm text-neutral-500">
        Upload a workflow exported from n8n, Make.com, Zapier, or any other
        platform. AutoPilot will convert it into an AutoGPT agent for you.
      </p>
      <FileInput
        mode="base64"
        value={importWorkflow.fileValue}
        onChange={importWorkflow.setFileValue}
        accept=".json,application/json"
        placeholder="Workflow file (n8n, Make.com, Zapier, ...)"
        maxFileSize={10 * 1024 * 1024}
        showStorageNote={false}
        className="mb-4 mt-2"
      />
      <Button
        type="button"
        variant="primary"
        className="w-full"
        disabled={!importWorkflow.fileValue || importWorkflow.isSubmitting}
        onClick={() => importWorkflow.submitWithMode("file")}
      >
        {importWorkflow.submittingMode === "file" ? (
          <div className="flex items-center gap-2">
            <LoadingSpinner size="small" className="text-white" />
            <span>Importing...</span>
          </div>
        ) : (
          "Import to AutoPilot"
        )}
      </Button>

      <div className="my-5 flex items-center gap-3">
        <div className="h-px flex-1 bg-neutral-200" />
        <span className="text-xs text-neutral-400">or import from URL</span>
        <div className="h-px flex-1 bg-neutral-200" />
      </div>

      <div className="mb-3 flex flex-wrap gap-2">
        {N8N_EXAMPLES.map((p) => (
          <button
            key={p.label}
            type="button"
            disabled={importWorkflow.isSubmitting}
            onClick={() => importWorkflow.setUrlValue(p.url)}
            className="rounded-full border border-neutral-200 px-3 py-1 text-xs text-neutral-600 hover:border-purple-400 hover:text-purple-600 disabled:opacity-50"
          >
            {p.label}
          </button>
        ))}
      </div>
      <Input
        id="template-url"
        value={importWorkflow.urlValue}
        onChange={(e) => importWorkflow.setUrlValue(e.target.value)}
        label="Workflow URL"
        placeholder="https://n8n.io/workflows/1234"
        className="mb-4 w-full rounded-[10px]"
      />
      <Button
        type="button"
        variant="primary"
        className="w-full"
        disabled={!importWorkflow.urlValue || importWorkflow.isSubmitting}
        onClick={() => importWorkflow.submitWithMode("url")}
      >
        {importWorkflow.submittingMode === "url" ? (
          <div className="flex items-center gap-2">
            <LoadingSpinner size="small" className="text-white" />
            <span>Importing...</span>
          </div>
        ) : (
          "Import from URL"
        )}
      </Button>
    </TabsLineContent>
  );
}
@@ -0,0 +1,85 @@
"use server";

/**
 * Regex to extract the numeric template ID from various n8n URL formats:
 * - https://n8n.io/workflows/1234
 * - https://n8n.io/workflows/1234-some-slug
 * - https://api.n8n.io/api/templates/workflows/1234
 */
const N8N_TEMPLATE_ID_RE = /n8n\.io\/(?:api\/templates\/)?workflows\/(\d+)/i;

/** Hardcoded n8n templates API base — the only URL we ever fetch. */
const N8N_TEMPLATES_API = "https://api.n8n.io/api/templates/workflows";

/** Max response body size (10 MB) to prevent memory exhaustion. */
const MAX_RESPONSE_BYTES = 10 * 1024 * 1024;

export type FetchWorkflowResult =
  | { ok: true; json: string }
  | { ok: false; error: string };

/**
 * Server action that fetches a workflow JSON from an n8n template URL.
 * Runs server-side so there are no CORS restrictions.
 *
 * Returns a result object instead of throwing because Next.js
 * server actions do not propagate error messages to the client.
 *
 * Only n8n.io workflow URLs are accepted. The template ID is extracted
 * and used to call the hardcoded n8n API — the user-supplied URL is
 * never passed to fetch() directly (SSRF prevention).
 */
export async function fetchWorkflowFromUrl(
  url: string,
): Promise<FetchWorkflowResult> {
  const match = url.match(N8N_TEMPLATE_ID_RE);
  if (!match) {
    return {
      ok: false,
      error:
        "Invalid or unsupported URL. " +
        "URL import is supported for n8n.io workflow templates " +
        "(e.g. https://n8n.io/workflows/1234). " +
        "For other platforms, use file upload.",
    };
  }

  const templateId = match[1]; // purely numeric, safe to interpolate

  try {
    const json = await fetchN8nWorkflow(templateId);
    return { ok: true, json };
  } catch (err) {
    return {
      ok: false,
      error: err instanceof Error ? err.message : "Failed to fetch workflow.",
    };
  }
}

async function fetchN8nWorkflow(templateId: string): Promise<string> {
  // Only ever fetch from the hardcoded API base + numeric ID.
  // parseInt + toString round-trips to guarantee the value is purely numeric,
  // preventing any path-traversal or SSRF via the interpolated segment.
  const safeId = parseInt(templateId, 10);
  if (!Number.isFinite(safeId) || safeId <= 0) {
    throw new Error("Invalid template ID");
  }
  const res = await fetch(`${N8N_TEMPLATES_API}/${safeId.toString()}`);
  if (!res.ok) throw new Error(`n8n template not found (${res.status})`);

  const contentLength = res.headers.get("content-length");
  if (contentLength && parseInt(contentLength, 10) > MAX_RESPONSE_BYTES) {
    throw new Error("Response too large.");
  }

  const text = await res.text();
  if (text.length > MAX_RESPONSE_BYTES) throw new Error("Response too large.");

  const data = JSON.parse(text);
  const template = data?.workflow ?? data;
  const workflow = template?.workflow ?? template;
  if (!workflow?.nodes) throw new Error("Unexpected n8n API response format");
  if (!workflow.name) workflow.name = template?.name ?? data?.name ?? "";
  return JSON.stringify(workflow);
}
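The parseInt round-trip in `fetchN8nWorkflow` is the whole SSRF defense: nothing from the user-supplied URL reaches `fetch()` except digits. A minimal standalone sketch of that guard (`safeTemplatePath` is an illustrative name, not part of this PR):

```typescript
// Sketch of the numeric round-trip guard, under the assumption that the
// API base is a trusted constant and only the ID segment is user-derived.
function safeTemplatePath(apiBase: string, rawId: string): string {
  const n = parseInt(rawId, 10);
  if (!Number.isFinite(n) || n <= 0) throw new Error("Invalid template ID");
  // n.toString() is purely numeric, so the joined URL can never smuggle
  // in "../", "@", "#", a query string, or a different host.
  return `${apiBase}/${n.toString()}`;
}

const base = "https://api.n8n.io/api/templates/workflows";
console.log(safeTemplatePath(base, "1234"));
// parseInt stops at the first non-digit, so a traversal suffix is dropped:
console.log(safeTemplatePath(base, "1234/../../admin"));
```

Both calls produce `.../workflows/1234`; a non-numeric ID like `"abc"` throws instead of being interpolated.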
@@ -0,0 +1,114 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { uploadFileDirect } from "@/lib/direct-upload";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { fetchWorkflowFromUrl } from "./fetchWorkflowFromUrl";

function decodeBase64Json(dataUrl: string): string {
  const match = dataUrl.match(/^data:[^;]+;base64,(.+)$/);
  if (!match) throw new Error("Could not read the uploaded file.");
  const binary = atob(match[1]);
  const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
  const json = new TextDecoder().decode(bytes);
  JSON.parse(json); // validate — throws SyntaxError if invalid
  return json;
}

async function uploadJsonAsFile(
  jsonString: string,
): Promise<{ fileId: string; fileName: string; mimeType: string }> {
  const file = new File(
    [new Blob([jsonString], { type: "application/json" })],
    `workflow-${crypto.randomUUID()}.json`,
    { type: "application/json" },
  );
  const uploaded = await uploadFileDirect(file);
  return {
    fileId: uploaded.file_id,
    fileName: uploaded.name,
    mimeType: uploaded.mime_type,
  };
}

function storeAndRedirect(
  fileInfo: { fileId: string; fileName: string; mimeType: string },
  router: ReturnType<typeof useRouter>,
) {
  sessionStorage.setItem(
    "importWorkflowPrompt",
    "Import this workflow and recreate it as an AutoGPT agent",
  );
  sessionStorage.setItem("importWorkflowFile", JSON.stringify(fileInfo));
  router.push("/copilot?source=import&autosubmit=true");
}

export function useExternalWorkflowTab() {
  const { toast } = useToast();
  const router = useRouter();
  const [fileValue, setFileValue] = useState("");
  const [urlValue, setUrlValue] = useState("");
  const [submittingMode, setSubmittingMode] = useState<"url" | "file" | null>(
    null,
  );
  const isSubmitting = submittingMode !== null;

  async function submitWithMode(mode: "url" | "file") {
    setSubmittingMode(mode);
    try {
      const jsonString = await resolveJson(mode);
      if (!jsonString) return;
      storeAndRedirect(await uploadJsonAsFile(jsonString), router);
    } catch (err) {
      toast({
        title: "Upload failed",
        description:
          err instanceof Error ? err.message : "Could not upload the file.",
        variant: "destructive",
      });
    } finally {
      setSubmittingMode(null);
    }
  }

  async function resolveJson(mode: "url" | "file"): Promise<string | null> {
    if (mode === "url") {
      const result = await fetchWorkflowFromUrl(urlValue);
      if (!result.ok) {
        toast({
          title: "Could not fetch workflow",
          description: result.error,
          variant: "destructive",
        });
        return null;
      }
      setUrlValue("");
      return result.json;
    }

    try {
      const json = decodeBase64Json(fileValue);
      setFileValue("");
      return json;
    } catch (err) {
      const isParseError = err instanceof SyntaxError;
      toast({
        title: isParseError ? "Invalid JSON" : "Invalid file",
        description: isParseError
          ? "The uploaded file is not valid JSON."
          : "Could not read the uploaded file.",
        variant: "destructive",
      });
      return null;
    }
  }

  return {
    submitWithMode,
    fileValue,
    setFileValue,
    urlValue,
    setUrlValue,
    isSubmitting,
    submittingMode,
  };
}
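`decodeBase64Json` above reverses what the `FileInput`'s `mode="base64"` produces: a `data:<mime>;base64,<payload>` string. The `atob` + `TextDecoder` pair matters because `atob` alone yields a Latin-1 string, not decoded UTF-8. A self-contained round-trip sketch (the function body mirrors the hook's helper; the usage values are made up):

```typescript
// Decode a base64 data URL back into its JSON text, validating it parses.
function decodeDataUrlJson(dataUrl: string): string {
  const match = dataUrl.match(/^data:[^;]+;base64,(.+)$/);
  if (!match) throw new Error("Not a base64 data URL");
  const binary = atob(match[1]); // base64 -> byte-per-char string
  const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
  const json = new TextDecoder().decode(bytes); // proper UTF-8 decoding
  JSON.parse(json); // fail fast with SyntaxError on malformed JSON
  return json;
}

const payload = btoa(JSON.stringify({ name: "My workflow", nodes: [] }));
console.log(decodeDataUrlJson(`data:application/json;base64,${payload}`));
```

`atob`/`btoa` and `TextDecoder` are globals in browsers and Node 16+, so no imports are needed.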
@@ -9,7 +9,9 @@ import { useForm } from "react-hook-form";
 import { z } from "zod";
 import { uploadAgentFormSchema } from "./LibraryUploadAgentDialog";

-export function useLibraryUploadAgentDialog() {
+export function useLibraryUploadAgentDialog(options?: {
+  onSuccess?: () => void;
+}) {
   const [isOpen, setIsOpen] = useState(false);
   const { toast } = useToast();
   const [agentObject, setAgentObject] = useState<Graph | null>(null);
@@ -19,6 +21,7 @@ export function useLibraryUploadAgentDialog() {
     mutation: {
       onSuccess: ({ data }) => {
         setIsOpen(false);
+        options?.onSuccess?.();
         toast({
           title: "Success",
           description: "Agent uploaded successfully",
@@ -114,7 +117,7 @@ export function useLibraryUploadAgentDialog() {
     }
   }, [agentFileValue, form, toast]);

-  const onSubmit = async (values: z.infer<typeof uploadAgentFormSchema>) => {
+  async function onSubmit(values: z.infer<typeof uploadAgentFormSchema>) {
     if (!agentObject) {
       form.setError("root", { message: "No Agent object to save" });
       return;
@@ -133,7 +136,7 @@ export function useLibraryUploadAgentDialog() {
         source: "upload",
       },
     });
-  };
+  }

   return {
     onSubmit,
@@ -14,9 +14,9 @@ import { Button } from "@/components/atoms/Button/Button";
 import { Text } from "@/components/atoms/Text/Text";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import { formatTimeAgo } from "@/lib/utils/time";
-import Link from "next/link";
-import { FileArrowDownIcon, PlusIcon } from "@phosphor-icons/react";
+import { PlusIcon } from "@phosphor-icons/react";
 import { User } from "@supabase/supabase-js";
+import Link from "next/link";
 import { useAgentInfo } from "./useAgentInfo";

 interface AgentInfoProps {
@@ -180,52 +180,57 @@ export const AgentInfo = ({
         {shortDescription}
       </div>

-      {/* Buttons + Runs */}
-      <div className="mt-6 flex w-full items-center justify-between lg:mt-8">
-        <div className="flex gap-3">
-          {user && (
-            <Button
-              variant="primary"
-              className="group/add min-w-36 border-violet-600 bg-violet-600 transition-shadow duration-300 hover:border-violet-500 hover:bg-violet-500 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]"
-              data-testid="agent-add-library-button"
-              disabled={isAddingAgentToLibrary}
-              loading={isAddingAgentToLibrary}
-              leftIcon={
-                !isAddingAgentToLibrary && !isAgentAddedToLibrary ? (
-                  <PlusIcon
-                    size={16}
-                    weight="bold"
-                    className="transition-transform duration-300 group-hover/add:rotate-90 group-hover/add:scale-125"
-                  />
-                ) : undefined
-              }
-              onClick={() =>
-                handleLibraryAction({
-                  isAddingAgentFirstTime: !isAgentAddedToLibrary,
-                })
-              }
-            >
-              {isAddingAgentToLibrary
-                ? "Adding..."
-                : isAgentAddedToLibrary
-                  ? "See runs"
-                  : "Add to library"}
-            </Button>
-          )}
+      {/* Buttons */}
+      <div className="mt-6 flex w-full items-center lg:mt-8">
+        {user && (
+          <Button
+            variant="primary"
+            className="group/add min-w-36 border-violet-600 bg-violet-600 transition-shadow duration-300 hover:border-violet-500 hover:bg-violet-500 hover:shadow-[0_0_20px_rgba(139,92,246,0.4)]"
+            data-testid="agent-add-library-button"
+            disabled={isAddingAgentToLibrary}
+            loading={isAddingAgentToLibrary}
+            leftIcon={
+              !isAddingAgentToLibrary && !isAgentAddedToLibrary ? (
+                <PlusIcon
+                  size={16}
+                  weight="bold"
+                  className="transition-transform duration-300 group-hover/add:rotate-90 group-hover/add:scale-125"
+                />
+              ) : undefined
+            }
+            onClick={() =>
+              handleLibraryAction({
+                isAddingAgentFirstTime: !isAgentAddedToLibrary,
+              })
+            }
+          >
+            {isAddingAgentToLibrary
+              ? "Adding..."
+              : isAgentAddedToLibrary
+                ? "See runs"
+                : "Add to library"}
+          </Button>
+        )}
       </div>

       {/* Download */}
       <div className="mt-3 flex w-full items-center justify-between gap-2">
         <div className="flex items-center gap-0">
           <Text variant="body" className="text-neutral-500">
             Want to use this agent locally?
           </Text>
           <Button
             variant="ghost"
             size="small"
             loading={isDownloadingAgent}
             onClick={() => handleDownload(agentId, name)}
             data-testid="agent-download-button"
             className="underline"
           >
-            {!isDownloadingAgent && <FileArrowDownIcon size={18} />}
-            {isDownloadingAgent ? "Downloading..." : "Download"}
+            {isDownloadingAgent ? "Downloading..." : "Download here"}
           </Button>
         </div>
-        <Text
-          variant="small"
-          className="mr-4 hidden whitespace-nowrap text-zinc-500 lg:block"
-        >
+        <Text variant="body" className="shrink-0 whitespace-nowrap">
           {runs === 0
             ? "No runs"
             : `${runs.toLocaleString()} run${runs > 1 ? "s" : ""}`}

@@ -113,7 +113,7 @@ export function StoreCard({

       {/* Third Section: Description */}
       <div className="mt-2.5 flex w-full flex-col">
-        <Text variant="body" className="line-clamp-3 leading-normal">
+        <Text variant="body" className="line-clamp-2 leading-normal">
           {description}
         </Text>
       </div>
File diff suppressed because it is too large
@@ -81,7 +81,7 @@ export function CredentialsInput({
     isHostScopedCredentialsModalOpen,
     isCredentialTypeSelectorOpen,
     isOAuth2FlowInProgress,
-    oAuthPopupController,
+    cancelOAuthFlow,
     actionButtonText,
     setAPICredentialsModalOpen,
     setUserPasswordCredentialsModalOpen,
@@ -158,7 +158,7 @@ export function CredentialsInput({
       {supportsOAuth2 && (
         <OAuthFlowWaitingModal
           open={isOAuth2FlowInProgress}
-          onClose={() => oAuthPopupController?.abort("canceled")}
+          onClose={cancelOAuthFlow}
           providerName={providerName}
         />
       )}

@@ -6,7 +6,12 @@ import {
   CredentialsMetaInput,
 } from "@/lib/autogpt-server-api/types";
 import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp";
-import { openOAuthPopup } from "@/lib/oauth-popup";
+import {
+  OAUTH_ERROR_FLOW_CANCELED,
+  OAUTH_ERROR_FLOW_TIMED_OUT,
+  OAUTH_ERROR_WINDOW_CLOSED,
+  openOAuthPopup,
+} from "@/lib/oauth-popup";
 import { useQueryClient } from "@tanstack/react-query";
 import { useEffect, useRef, useState } from "react";
 import {
@@ -49,8 +54,6 @@ export function useCredentialsInput({
   const [isCredentialTypeSelectorOpen, setCredentialTypeSelectorOpen] =
     useState(false);
   const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
-  const [oAuthPopupController, setOAuthPopupController] =
-    useState<AbortController | null>(null);
   const [oAuthError, setOAuthError] = useState<string | null>(null);
   const [credentialToDelete, setCredentialToDelete] = useState<{
     id: string;
@@ -212,12 +215,6 @@ export function useCredentialsInput({
     });

     oauthAbortRef.current = cleanup.abort;
-    // Expose abort signal for the waiting modal's cancel button
-    const controller = new AbortController();
-    cleanup.signal.addEventListener("abort", () =>
-      controller.abort("completed"),
-    );
-    setOAuthPopupController(controller);

     const result = await promise;

@@ -252,14 +249,16 @@ export function useCredentialsInput({
         provider,
       });
     } catch (error) {
-      if (error instanceof Error && error.message === "OAuth flow timed out") {
-        setOAuthError("OAuth flow timed out");
+      const message = error instanceof Error ? error.message : String(error);
+      if (
+        message === OAUTH_ERROR_WINDOW_CLOSED ||
+        message === OAUTH_ERROR_FLOW_CANCELED
+      ) {
+        // User closed the popup or clicked cancel — not an error
+      } else if (message === OAUTH_ERROR_FLOW_TIMED_OUT) {
+        setOAuthError(OAUTH_ERROR_FLOW_TIMED_OUT);
       } else {
-        setOAuthError(
-          `OAuth error: ${
-            error instanceof Error ? error.message : String(error)
-          }`,
-        );
+        setOAuthError(`OAuth error: ${message}`);
       }
     } finally {
       setOAuth2FlowInProgress(false);
@@ -311,6 +310,10 @@ export function useCredentialsInput({
     }
   }

+  function cancelOAuthFlow() {
+    oauthAbortRef.current?.("canceled");
+  }
+
   function handleDeleteCredential(credential: { id: string; title: string }) {
     setCredentialToDelete(credential);
   }
@@ -345,7 +348,7 @@ export function useCredentialsInput({
     isHostScopedCredentialsModalOpen,
     isCredentialTypeSelectorOpen,
     isOAuth2FlowInProgress,
-    oAuthPopupController,
+    cancelOAuthFlow,
     credentialToDelete,
     deleteCredentialsMutation,
     actionButtonText: getActionButtonText(
@@ -169,7 +169,7 @@ function renderMarkdown(
         [remarkMath, { singleDollarTextMath: false }], // Math support for LaTeX
       ]}
       rehypePlugins={[
-        rehypeKatex, // Render math with KaTeX
+        [rehypeKatex, { strict: false }], // Render math with KaTeX
         rehypeHighlight, // Syntax highlighting for code blocks
         rehypeSlug, // Add IDs to headings
        [rehypeAutolinkHeadings, { behavior: "wrap" }], // Make headings clickable

@@ -28,6 +28,7 @@ export async function uploadFileDirect(
   if (sessionID) {
     url.searchParams.set("session_id", sessionID);
   }
+  url.searchParams.set("overwrite", "true");

   const formData = new FormData();
   formData.append("file", file);

@@ -8,6 +8,10 @@

 const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes

+export const OAUTH_ERROR_WINDOW_CLOSED = "Sign-in window was closed";
+export const OAUTH_ERROR_FLOW_CANCELED = "OAuth flow was canceled";
+export const OAUTH_ERROR_FLOW_TIMED_OUT = "OAuth flow timed out";
+
 export type OAuthPopupResult = {
   code: string;
   state: string;
@@ -156,11 +160,34 @@ export function openOAuthPopup(
     );
   }

+  // Detect popup closed by user (without completing sign-in)
+  if (popup) {
+    const closedPollInterval = setInterval(() => {
+      if (popup.closed && !handled) {
+        clearInterval(closedPollInterval);
+        handled = true;
+        reject(new Error(OAUTH_ERROR_WINDOW_CLOSED));
+        controller.abort("popup_closed");
+      }
+    }, 500);
+    controller.signal.addEventListener("abort", () =>
+      clearInterval(closedPollInterval),
+    );
+  }
+
+  // Reject on abort (e.g. from cancel button in the waiting modal)
+  controller.signal.addEventListener("abort", () => {
+    if (!handled) {
+      handled = true;
+      reject(new Error(OAUTH_ERROR_FLOW_CANCELED));
+    }
+  });
+
   // Timeout
   const timeoutId = setTimeout(() => {
     if (!handled) {
       handled = true;
-      reject(new Error("OAuth flow timed out"));
+      reject(new Error(OAUTH_ERROR_FLOW_TIMED_OUT));
       controller.abort("timeout");
     }
   }, timeout);
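The `handled` flag in `openOAuthPopup` above implements single settlement: several outcome sources (popup closed, cancel abort, timeout) race, and whichever fires first wins while the rest become no-ops. A minimal standalone sketch of that pattern (the names `raceOutcomes`, `cancel`, and `complete` are illustrative, not the PR's actual API):

```typescript
// A promise with multiple competing outcome sources; a shared `handled`
// flag ensures it settles exactly once, and the loser paths clean up.
function raceOutcomes(timeoutMs: number): {
  promise: Promise<string>;
  cancel: () => void;
  complete: () => void;
} {
  let handled = false;
  let resolveFn!: (v: string) => void;
  let rejectFn!: (e: Error) => void;
  const promise = new Promise<string>((res, rej) => {
    resolveFn = res;
    rejectFn = rej;
  });
  const timeoutId = setTimeout(() => {
    if (!handled) {
      handled = true;
      rejectFn(new Error("timed out"));
    }
  }, timeoutMs);
  return {
    promise,
    cancel: () => {
      if (!handled) {
        handled = true;
        clearTimeout(timeoutId);
        rejectFn(new Error("canceled"));
      }
    },
    complete: () => {
      if (!handled) {
        handled = true;
        clearTimeout(timeoutId);
        resolveFn("ok");
      }
    },
  };
}

const flow = raceOutcomes(5_000);
flow.complete();
flow.cancel(); // no effect — the promise already settled
```

Without the flag, a late timeout or a cancel racing a completion would call `reject` on an already-resolved promise; the second settlement is silently ignored by the Promise spec, but side effects like `controller.abort()` would still run, which is why the guard wraps the whole branch.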
@@ -24,7 +24,7 @@ test.describe("Library", () => {
     await page.goto("/library");

     await expect(getId("search-bar").first()).toBeVisible();
-    await expect(getId("upload-agent-button").first()).toBeVisible();
+    await expect(getId("import-button").first()).toBeVisible();
     await expect(getId("sort-by-dropdown").first()).toBeVisible();
   });

@@ -171,7 +171,6 @@ test.describe("Library", () => {
       expect(matchingPaginatedResults.length).toEqual(
         allPaginatedResults.length,
       );
-    } else {
     }

     await libraryPage.scrollAndWaitForNewAgents();

@@ -109,19 +109,23 @@ export class LibraryPage extends BasePage {

   async openUploadDialog(): Promise<void> {
     console.log(`opening upload dialog`);
-    await this.page.getByRole("button", { name: "Upload agent" }).click();
+    // Open the unified Import dialog first
+    await this.page.getByRole("button", { name: "Import" }).click();

     // Wait for dialog to appear
-    await this.page.getByRole("dialog", { name: "Upload Agent" }).waitFor({
+    await this.page.getByRole("dialog", { name: "Import" }).waitFor({
       state: "visible",
       timeout: 5_000,
     });
+
+    // Click the "AutoGPT agent" tab
+    await this.page.getByRole("tab", { name: "AutoGPT agent" }).click();
   }

   async closeUploadDialog(): Promise<void> {
     await this.page.getByRole("button", { name: "Close" }).click();

-    await this.page.getByRole("dialog", { name: "Upload Agent" }).waitFor({
+    await this.page.getByRole("dialog", { name: "Import" }).waitFor({
       state: "hidden",
       timeout: 5_000,
     });
@@ -130,7 +134,7 @@ export class LibraryPage extends BasePage {
   async isUploadDialogVisible(): Promise<boolean> {
     console.log(`checking if upload dialog is visible`);
     try {
-      const dialog = this.page.getByRole("dialog", { name: "Upload Agent" });
+      const dialog = this.page.getByRole("dialog", { name: "Import" });
       return await dialog.isVisible();
     } catch {
       return false;