Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-03-17 03:00:27 -04:00.

Compare commits: 3 commits, `feat/githu...` → `docker/opt...`

| Author | SHA1 | Date |
|---|---|---|
| | 1ed748a356 | |
| | 9c28639c32 | |
| | 4f37a12743 | |
@@ -1,79 +0,0 @@
---
name: pr-address
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Address

## Find the PR

```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```

## Fetch comments (all sources)

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews   # top-level reviews
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments  # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
```

**Bots to watch for:**

- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — automated review. Address actionable items.
## For each unaddressed comment

Address comments **one at a time**: fix → commit → push → inline reply → next.

1. Read the referenced code, make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):

| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |

## Format and commit

After fixing, format the changed code:

- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`

If API routes changed, regenerate the frontend client:

```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
trap "kill $REST_PID 2>/dev/null" EXIT
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID 2>/dev/null; trap - EXIT
```

Never manually edit files in `src/app/api/__generated__/`.

Then commit and **push immediately** — never batch commits without pushing.

For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).

## The loop

```text
address comments → format → commit → push
→ re-check comments → fix new ones → push
→ wait for CI → re-check comments after CI settles
→ repeat until: all comments addressed AND CI green AND no new comments arriving
```

While CI runs, stay productive: run local tests, address remaining comments.

**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
@@ -1,74 +0,0 @@
---
name: pr-review
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Review

## Find the PR

```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```

## Read the diff

```bash
gh pr diff {N}
```

## Fetch existing review comments

Before posting anything, fetch existing inline comments to avoid duplicates:

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
```

## What to check

**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).

**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).

**Code quality:** apply rules from backend/frontend CLAUDE.md files.

**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.

**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
## Output format

Every comment **must** be prefixed with `🤖` and a criticality badge:

| Tier | Badge | Meaning |
|---|---|---|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
| Should Fix | `🟠 **Should Fix**` | Important improvement |
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
| Nit | `🔵 **Nit**` | Style / wording |

Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
## Post inline comments

For each finding, post an inline comment on the PR (do not just write a local report):

```bash
# Get the latest commit SHA for the PR
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')

# Post an inline comment on a specific file/line
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
  -f body="🤖 🔴 **Blocker**: <description>" \
  -f commit_id="$COMMIT_SHA" \
  -f path="<file path>" \
  -F line=<line number>
```
@@ -1,85 +0,0 @@
---
name: worktree
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
metadata:
  author: autogpt-team
  version: "3.0.0"
---

# Worktree Setup

## Create the worktree

Derive paths from the git toplevel. If a name is provided as an argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.

```bash
ROOT=$(git rev-parse --show-toplevel)
PARENT=$(dirname "$ROOT")

# From an existing branch
git worktree add "$PARENT/<NAME>" <branch-name>

# From a new branch off dev
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
```
## Copy environment files

Copy `.env` from the root worktree, falling back to `.env.default` if `.env` doesn't exist.

```bash
ROOT=$(git rev-parse --show-toplevel)
TARGET="$(dirname "$ROOT")/<NAME>"

for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
  if [ -f "$ROOT/$envpath/.env" ]; then
    cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
  elif [ -f "$ROOT/$envpath/.env.default" ]; then
    cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
  fi
done
```

## Install dependencies

```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
cd "$TARGET/autogpt_platform/frontend" && pnpm install
```

Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).

## Running the app (optional)

Backend uses ports 8001, 8002, 8003, 8005, 8006, 8007, and 8008. Free them first if needed:

```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
for port in 8001 8002 8003 8005 8006 8007 8008; do
  lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd "$TARGET/autogpt_platform/backend" && poetry run app
```

## CoPilot testing

SDK mode spawns a Claude subprocess, so it won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

## Cleanup

```bash
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
```

## Alternative: Branchlet (optional)

If [branchlet](https://www.npmjs.com/package/branchlet) is installed:

```bash
branchlet create -n <name> -s <source-branch> -b <new-branch>
```
@@ -5,13 +5,42 @@
 !docs/
 
 # Platform - Libs
-!autogpt_platform/autogpt_libs/
+!autogpt_platform/autogpt_libs/autogpt_libs/
+!autogpt_platform/autogpt_libs/pyproject.toml
+!autogpt_platform/autogpt_libs/poetry.lock
+!autogpt_platform/autogpt_libs/README.md
 
 # Platform - Backend
-!autogpt_platform/backend/
+!autogpt_platform/backend/backend/
+!autogpt_platform/backend/test/e2e_test_data.py
+!autogpt_platform/backend/migrations/
+!autogpt_platform/backend/schema.prisma
+!autogpt_platform/backend/pyproject.toml
+!autogpt_platform/backend/poetry.lock
+!autogpt_platform/backend/README.md
+!autogpt_platform/backend/.env
+!autogpt_platform/backend/gen_prisma_types_stub.py
 
+# Platform - Market
+!autogpt_platform/market/market/
+!autogpt_platform/market/scripts.py
+!autogpt_platform/market/schema.prisma
+!autogpt_platform/market/pyproject.toml
+!autogpt_platform/market/poetry.lock
+!autogpt_platform/market/README.md
+
 # Platform - Frontend
-!autogpt_platform/frontend/
+!autogpt_platform/frontend/src/
+!autogpt_platform/frontend/public/
+!autogpt_platform/frontend/scripts/
+!autogpt_platform/frontend/package.json
+!autogpt_platform/frontend/pnpm-lock.yaml
+!autogpt_platform/frontend/tsconfig.json
+!autogpt_platform/frontend/README.md
+## config
+!autogpt_platform/frontend/*.config.*
+!autogpt_platform/frontend/.env.*
+!autogpt_platform/frontend/.env
 
 # Classic - AutoGPT
 !classic/original_autogpt/autogpt/
@@ -35,38 +64,6 @@
 # Classic - Frontend
 !classic/frontend/build/web/
 
-# Explicitly re-ignore unwanted files from whitelisted directories
-# Note: These patterns MUST come after the whitelist rules to take effect
-# Hidden files and directories (but keep frontend .env files needed for build)
-**/.*
-!autogpt_platform/frontend/.env
-!autogpt_platform/frontend/.env.default
-!autogpt_platform/frontend/.env.production
-
-# Python artifacts
-**/__pycache__/
-**/*.pyc
-**/*.pyo
-**/.venv/
-**/.ruff_cache/
-**/.pytest_cache/
-**/.coverage
-**/htmlcov/
-
-# Node artifacts
-**/node_modules/
-**/.next/
-**/storybook-static/
-**/playwright-report/
-**/test-results/
-
-# Build artifacts
-**/dist/
-**/build/
-!autogpt_platform/frontend/src/**/build/
-**/target/
-
-# Logs and temp files
-**/*.log
-**/*.tmp
+# Explicitly re-ignore some folders
+.*
+**/__pycache__
.github/scripts/detect_overlaps.py (vendored, 1229 changed lines): file diff suppressed because it is too large.

.github/workflows/classic-frontend-ci.yml (vendored, 2 changed lines)
@@ -49,7 +49,7 @@ jobs:
       - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
         if: github.event_name == 'push'
-        uses: peter-evans/create-pull-request@v8
+        uses: peter-evans/create-pull-request@v7
         with:
           add-paths: classic/frontend/build/web
           base: ${{ github.ref_name }}
.github/workflows/claude-ci-failure-auto-fix.yml (vendored, 46 changed lines)

@@ -22,7 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           ref: ${{ github.event.workflow_run.head_branch }}
           fetch-depth: 0
@@ -40,51 +40,9 @@ jobs:
           git checkout -b "$BRANCH_NAME"
           echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
 
-      # Backend Python/Poetry setup (so Claude can run linting/tests)
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.11"
-
-      - name: Set up Python dependency cache
-        uses: actions/cache@v5
-        with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
-
-      - name: Install Poetry
-        run: |
-          cd autogpt_platform/backend
-          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
-          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
-          echo "$HOME/.local/bin" >> $GITHUB_PATH
-
-      - name: Install Python dependencies
-        working-directory: autogpt_platform/backend
-        run: poetry install
-
-      - name: Generate Prisma Client
-        working-directory: autogpt_platform/backend
-        run: poetry run prisma generate && poetry run gen-prisma-stub
-
-      # Frontend Node.js/pnpm setup (so Claude can run linting/tests)
-      - name: Enable corepack
-        run: corepack enable
-
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
-
-      - name: Install JavaScript dependencies
-        working-directory: autogpt_platform/frontend
-        run: pnpm install --frozen-lockfile
-
       - name: Get CI failure details
         id: failure_details
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const run = await github.rest.actions.getWorkflowRun({
.github/workflows/claude-dependabot.yml (vendored, 29 changed lines)

@@ -30,7 +30,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 1
@@ -41,7 +41,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -77,15 +77,27 @@ jobs:
         run: poetry run prisma generate && poetry run gen-prisma-stub
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22"
+
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
+      - name: Set pnpm store directory
+        run: |
+          pnpm config set store-dir ~/.pnpm-store
+          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
+
+      - name: Cache frontend dependencies
+        uses: actions/cache@v4
         with:
-          node-version: "22"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+          path: ~/.pnpm-store
+          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
       - name: Install JavaScript dependencies
         working-directory: autogpt_platform/frontend
@@ -112,7 +124,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
@@ -297,7 +309,6 @@ jobs:
         uses: anthropics/claude-code-action@v1
         with:
           claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
-          allowed_bots: "dependabot[bot]"
           claude_args: |
             --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
           prompt: |
.github/workflows/claude.yml (vendored, 28 changed lines)

@@ -40,7 +40,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 1
@@ -57,7 +57,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -93,15 +93,27 @@ jobs:
         run: poetry run prisma generate && poetry run gen-prisma-stub
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22"
+
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v6
+      - name: Set pnpm store directory
+        run: |
+          pnpm config set store-dir ~/.pnpm-store
+          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
+
+      - name: Cache frontend dependencies
+        uses: actions/cache@v4
         with:
-          node-version: "22"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+          path: ~/.pnpm-store
+          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
       - name: Install JavaScript dependencies
         working-directory: autogpt_platform/frontend
@@ -128,7 +140,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
.github/workflows/codeql.yml (vendored, 6 changed lines)

@@ -58,11 +58,11 @@ jobs:
       # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
 
       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
-        uses: github/codeql-action/init@v4
+        uses: github/codeql-action/init@v3
         with:
           languages: ${{ matrix.language }}
           build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
           exit 1
 
       - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@v4
+        uses: github/codeql-action/analyze@v3
         with:
           category: "/language:${{matrix.language}}"
.github/workflows/copilot-setup-steps.yml (vendored, 10 changed lines)

@@ -27,7 +27,7 @@ jobs:
       # If you do not check out your code, Copilot will do this for you.
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
           submodules: true
@@ -39,7 +39,7 @@ jobs:
           python-version: "3.11" # Use standard version matching CI
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
 
       # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
       - name: Set up Node.js
-        uses: actions/setup-node@v6
+        uses: actions/setup-node@v4
         with:
           node-version: "22"
@@ -89,7 +89,7 @@ jobs:
           echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
 
       - name: Cache frontend dependencies
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.pnpm-store
           key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
       # Phase 1: Cache and load Docker images for faster setup
       - name: Set up Docker image cache
         id: docker-cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/docker-cache
           # Use a versioned key for cache invalidation when image list changes
.github/workflows/docs-block-sync.yml (vendored, 4 changed lines)

@@ -23,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 1
@@ -33,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
.github/workflows/docs-claude-review.yml (vendored, 38 changed lines)

@@ -7,10 +7,6 @@ on:
       - "docs/integrations/**"
       - "autogpt_platform/backend/backend/blocks/**"
 
-concurrency:
-  group: claude-docs-review-${{ github.event.pull_request.number }}
-  cancel-in-progress: true
-
 jobs:
   claude-review:
     # Only run for PRs from members/collaborators
@@ -27,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
@@ -37,7 +33,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -95,35 +91,5 @@ jobs:
           3. Read corresponding documentation files to verify accuracy
           4. Provide your feedback as a PR comment
 
-          ## IMPORTANT: Comment Marker
-          Start your PR comment with exactly this HTML comment marker on its own line:
-          <!-- CLAUDE_DOCS_REVIEW -->
-
-          This marker is used to identify and replace your comment on subsequent runs.
-
           Be constructive and specific. If everything looks good, say so!
           If there are issues, explain what's wrong and suggest how to fix it.
 
-      - name: Delete old Claude review comments
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          # Get all comment IDs with our marker, sorted by creation date (oldest first)
-          COMMENT_IDS=$(gh api \
-            repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
-            --jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
-
-          # Count comments
-          COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
-
-          if [ "$COMMENT_COUNT" -gt 1 ]; then
-            # Delete all but the last (newest) comment
-            echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
-              if [ -n "$COMMENT_ID" ]; then
-                echo "Deleting old review comment: $COMMENT_ID"
-                gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
-              fi
-            done
-          else
-            echo "No old review comments to clean up"
-          fi
|
|||||||
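The removed cleanup step above implements a keep-newest policy over marker-tagged comments: collect IDs oldest-first, then delete everything except the final entry. A minimal sketch of that selection logic, with hypothetical IDs standing in for `gh api ... --jq` output and `echo` standing in for the delete call:

```shell
#!/bin/sh
# Keep-newest selection over marker-tagged comment IDs, oldest first.
# COMMENT_IDS is a stand-in for the `gh api ... --jq` output.
COMMENT_IDS="101
102
103"

COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)

if [ "$COMMENT_COUNT" -gt 1 ]; then
  # head -n -1 drops the last (newest) ID; the rest are deletion candidates.
  echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
    [ -n "$COMMENT_ID" ] && echo "would delete: $COMMENT_ID"
  done
else
  echo "nothing to clean up"
fi
```

With the three IDs above this prints `would delete: 101` and `would delete: 102`, leaving the newest comment (103) in place.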
### `.github/workflows/docs-enhance.yml` (4 changes)

```diff
@@ -28,7 +28,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 1
 
@@ -38,7 +38,7 @@ jobs:
           python-version: "3.11"
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
           key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
```
```diff
@@ -25,7 +25,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           ref: ${{ github.event.inputs.git_ref || github.ref_name }}
 
@@ -52,7 +52,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
        with:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
```
```diff
@@ -17,7 +17,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           ref: ${{ github.ref_name || 'master' }}
 
@@ -45,7 +45,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Trigger deploy workflow
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
        with:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
```
### `.github/workflows/platform-backend-ci.yml` (13 changes)

```diff
@@ -41,18 +41,13 @@ jobs:
         ports:
           - 6379:6379
       rabbitmq:
-        image: rabbitmq:4.1.4
+        image: rabbitmq:3.12-management
         ports:
           - 5672:5672
+          - 15672:15672
         env:
           RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
           RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
-        options: >-
-          --health-cmd "rabbitmq-diagnostics -q ping"
-          --health-interval 30s
-          --health-timeout 10s
-          --health-retries 5
-          --health-start-period 10s
       clamav:
         image: clamav/clamav-debian:latest
         ports:
@@ -73,7 +68,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
           submodules: true
@@ -93,7 +88,7 @@ jobs:
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
 
       - name: Set up Python dependency cache
-        uses: actions/cache@v5
+        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
```
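The `--health-cmd` service options removed above and the `timeout 60 sh -c 'until ...'` steps elsewhere in these workflows are two forms of the same readiness gate: poll a probe until it succeeds or a deadline passes. A minimal sketch of that pattern, where `test -e /` is a placeholder probe standing in for something like `pg_isready -U postgres` or `rabbitmq-diagnostics -q ping`:

```shell
#!/bin/sh
# Readiness gate: poll a probe command until it succeeds or the
# deadline expires. PROBE is a placeholder; a real workflow would use
# e.g. "pg_isready -U postgres" or "curl -f http://localhost:8006/health".
PROBE="test -e /"

if timeout 10 sh -c "until $PROBE; do sleep 2; done"; then
  STATUS=ready
else
  STATUS=timeout
fi
echo "service: $STATUS"
```

Appending `|| echo "... timeout, continuing..."` instead of branching, as the workflows do, turns the gate into a best-effort wait that never fails the job.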
```diff
@@ -17,7 +17,7 @@ jobs:
       - name: Check comment permissions and deployment status
         id: check_status
         if: github.event_name == 'issue_comment' && github.event.issue.pull_request
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
 
       - name: Post permission denied comment
         if: steps.check_status.outputs.permission_denied == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
       - name: Get PR details for deployment
         id: pr_details
         if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const pr = await github.rest.pulls.get({
@@ -82,7 +82,7 @@ jobs:
 
       - name: Dispatch Deploy Event
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -98,7 +98,7 @@ jobs:
 
       - name: Post deploy success comment
         if: steps.check_status.outputs.should_deploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -110,7 +110,7 @@ jobs:
 
       - name: Dispatch Undeploy Event (from comment)
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -126,7 +126,7 @@ jobs:
 
       - name: Post undeploy success comment
         if: steps.check_status.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
       - name: Check deployment status on PR close
         id: check_pr_close
         if: github.event_name == 'pull_request' && github.event.action == 'closed'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
         with:
           script: |
             const comments = await github.rest.issues.listComments({
@@ -168,7 +168,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: peter-evans/repository-dispatch@v4
+        uses: peter-evans/repository-dispatch@v3
         with:
           token: ${{ secrets.DISPATCH_TOKEN }}
           repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -187,7 +187,7 @@ jobs:
           github.event_name == 'pull_request' &&
           github.event.action == 'closed' &&
           steps.check_pr_close.outputs.should_undeploy == 'true'
-        uses: actions/github-script@v8
+        uses: actions/github-script@v7
        with:
          script: |
            await github.rest.issues.createComment({
```
### `.github/workflows/platform-frontend-ci.yml` (275 changes)

```diff
@@ -6,16 +6,10 @@ on:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   pull_request:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   merge_group:
   workflow_dispatch:
 
@@ -32,31 +26,34 @@ jobs:
   setup:
     runs-on: ubuntu-latest
     outputs:
-      components-changed: ${{ steps.filter.outputs.components }}
+      cache-key: ${{ steps.cache-key.outputs.key }}
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
 
-      - name: Check for component changes
-        uses: dorny/paths-filter@v3
-        id: filter
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
         with:
-          filters: |
-            components:
-              - 'autogpt_platform/frontend/src/components/**'
+          node-version: "22.18.0"
 
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+      - name: Generate cache key
+        id: cache-key
+        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
 
-      - name: Install dependencies to populate cache
+      - name: Cache dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ steps.cache-key.outputs.key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
+
+      - name: Install dependencies
         run: pnpm install --frozen-lockfile
 
   lint:
@@ -65,17 +62,24 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
 
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22.18.0"
+
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
 
@@ -86,27 +90,31 @@ jobs:
   chromatic:
     runs-on: ubuntu-latest
     needs: setup
-    # Disabled: to re-enable, remove 'false &&' from the condition below
-    if: >-
-      false
-      && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
-      && needs.setup.outputs.components-changed == 'true'
+    # Only run on dev branch pushes or PRs targeting dev
+    if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
 
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22.18.0"
+
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
 
@@ -121,20 +129,30 @@ jobs:
           exitOnceUploaded: true
 
   e2e_test:
-    name: end-to-end tests
     runs-on: big-boi
+    needs: setup
+    strategy:
+      fail-fast: false
 
    steps:
      - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
        with:
          submodules: recursive
 
-      - name: Set up Platform - Copy default supabase .env
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22.18.0"
+
+      - name: Enable corepack
+        run: corepack enable
+
+      - name: Copy default supabase .env
        run: |
          cp ../.env.default ../.env
 
-      - name: Set up Platform - Copy backend .env and set OpenAI API key
+      - name: Copy backend .env and set OpenAI API key
        run: |
          cp ../backend/.env.default ../backend/.env
          echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
@@ -142,125 +160,77 @@ jobs:
           # Used by E2E test data script to generate embeddings for approved store agents
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
 
-      - name: Set up Platform - Set up Docker Buildx
+      - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+
+      - name: Cache Docker layers
+        uses: actions/cache@v4
         with:
-          driver: docker-container
-          driver-opts: network=host
+          path: /tmp/.buildx-cache
+          key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
+          restore-keys: |
+            ${{ runner.os }}-buildx-frontend-test-
 
-      - name: Set up Platform - Expose GHA cache to docker buildx CLI
-        uses: crazy-max/ghaction-github-runtime@v4
-
-      - name: Set up Platform - Build Docker images (with cache)
-        working-directory: autogpt_platform
+      - name: Run docker compose
         run: |
-          pip install pyyaml
-
-          # Resolve extends and generate a flat compose file that bake can understand
-          docker compose -f docker-compose.yml config > docker-compose.resolved.yml
-
-          # Add cache configuration to the resolved compose file
-          python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
-            --source docker-compose.resolved.yml \
-            --cache-from "type=gha" \
-            --cache-to "type=gha,mode=max" \
-            --backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
-            --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
-            --git-ref "${{ github.ref }}"
-
-          # Build with bake using the resolved compose file (now includes cache config)
-          docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
+          NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
         env:
-          NEXT_PUBLIC_PW_TEST: true
+          DOCKER_BUILDKIT: 1
+          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
+          BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
 
-      - name: Set up tests - Cache E2E test data
-        id: e2e-data-cache
-        uses: actions/cache@v5
-        with:
-          path: /tmp/e2e_test_data.sql
-          key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
-
-      - name: Set up Platform - Start Supabase DB + Auth
+      - name: Move cache
         run: |
-          docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
-          echo "Waiting for database to be ready..."
-          timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
-          echo "Waiting for auth service to be ready..."
-          timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
+          rm -rf /tmp/.buildx-cache
+          if [ -d "/tmp/.buildx-cache-new" ]; then
+            mv /tmp/.buildx-cache-new /tmp/.buildx-cache
+          fi
 
-      - name: Set up Platform - Run migrations
+      - name: Wait for services to be ready
         run: |
-          echo "Running migrations..."
-          docker compose -f ../docker-compose.resolved.yml run --rm migrate
-          echo "✅ Migrations completed"
-        env:
-          NEXT_PUBLIC_PW_TEST: true
-
-      - name: Set up tests - Load cached E2E test data
-        if: steps.e2e-data-cache.outputs.cache-hit == 'true'
-        run: |
-          echo "✅ Found cached E2E test data, restoring..."
-          {
-            echo "SET session_replication_role = 'replica';"
-            cat /tmp/e2e_test_data.sql
-            echo "SET session_replication_role = 'origin';"
-          } | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
-          # Refresh materialized views after restore
-          docker compose -f ../docker-compose.resolved.yml exec -T db \
-            psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
-
-          echo "✅ E2E test data restored from cache"
-
-      - name: Set up Platform - Start (all other services)
-        run: |
-          docker compose -f ../docker-compose.resolved.yml up -d --no-build
           echo "Waiting for rest_server to be ready..."
           timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
-        env:
-          NEXT_PUBLIC_PW_TEST: true
+          echo "Waiting for database to be ready..."
+          timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
 
-      - name: Set up tests - Create E2E test data
-        if: steps.e2e-data-cache.outputs.cache-hit != 'true'
+      - name: Create E2E test data
         run: |
           echo "Creating E2E test data..."
-          docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
-          docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
-            echo "❌ E2E test data creation failed!"
-            docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
-            exit 1
-          }
-
-          # Dump auth.users + platform schema for cache (two separate dumps)
-          echo "Dumping database for cache..."
-          {
-            docker compose -f ../docker-compose.resolved.yml exec -T db \
-              pg_dump -U postgres --data-only --column-inserts \
-              --table='auth.users' postgres
-            docker compose -f ../docker-compose.resolved.yml exec -T db \
-              pg_dump -U postgres --data-only --column-inserts \
-              --schema=platform \
-              --exclude-table='platform._prisma_migrations' \
-              --exclude-table='platform.apscheduler_jobs' \
-              --exclude-table='platform.apscheduler_jobs_batched_notifications' \
-              postgres
-          } > /tmp/e2e_test_data.sql
-
-          echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
-
-      - name: Set up tests - Enable corepack
-        run: corepack enable
-
-      - name: Set up tests - Set up Node
-        uses: actions/setup-node@v6
+          # First try to run the script from inside the container
+          if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
+            echo "✅ Found e2e_test_data.py in container, running it..."
+            docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
+              echo "❌ E2E test data creation failed!"
+              docker compose -f ../docker-compose.yml logs --tail=50 rest_server
+              exit 1
+            }
+          else
+            echo "⚠️ e2e_test_data.py not found in container, copying and running..."
+            # Copy the script into the container and run it
+            docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
+              echo "❌ Failed to copy script to container"
+              exit 1
+            }
+            docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
+              echo "❌ E2E test data creation failed!"
+              docker compose -f ../docker-compose.yml logs --tail=50 rest_server
+              exit 1
+            }
+          fi
 
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
         with:
-          node-version: "22.18.0"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
-      - name: Set up tests - Install dependencies
+      - name: Install dependencies
         run: pnpm install --frozen-lockfile
 
-      - name: Set up tests - Install browser 'chromium'
+      - name: Install Browser 'chromium'
         run: pnpm playwright install --with-deps chromium
 
       - name: Run Playwright tests
@@ -287,7 +257,7 @@ jobs:
 
       - name: Print Final Docker Compose logs
         if: always()
-        run: docker compose -f ../docker-compose.resolved.yml logs
+        run: docker compose -f ../docker-compose.yml logs
 
   integration_test:
     runs-on: ubuntu-latest
@@ -295,19 +265,26 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v6
+        uses: actions/checkout@v4
         with:
           submodules: recursive
 
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "22.18.0"
+
       - name: Enable corepack
         run: corepack enable
 
-      - name: Set up Node
-        uses: actions/setup-node@v6
-        with:
-          node-version: "22.18.0"
-          cache: "pnpm"
-          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+            ${{ runner.os }}-pnpm-
 
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
```
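The E2E data restore step removed in the frontend CI diff above sandwiches the cached SQL dump between two `session_replication_role` statements, which suspends foreign-key triggers while the bulk inserts replay. A minimal sketch of that wrapping; the table name is hypothetical, the here-doc stands in for `/tmp/e2e_test_data.sql`, and in the workflow the result is piped straight into `psql` rather than captured:

```shell
#!/bin/sh
# Wrap a SQL dump so FK triggers are disabled during bulk restore.
# The dump file and table name are placeholders for this sketch.
cat > /tmp/dump_sketch.sql <<'EOF'
INSERT INTO platform.example (id) VALUES (1);
EOF

WRAPPED=$({
  echo "SET session_replication_role = 'replica';"
  cat /tmp/dump_sketch.sql
  echo "SET session_replication_role = 'origin';"
})
echo "$WRAPPED"
```

Restoring under `replica` mode is why the workflow also refreshes materialized views afterwards: triggers that would normally keep derived data current do not fire.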
16
.github/workflows/platform-fullstack-ci.yml
vendored
16
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -56,19 +56,19 @@ jobs:
|
|||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
types:
|
||||||
runs-on: big-boi
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -85,10 +85,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Run docker compose
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
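The cache-key step above uses GitHub's `hashFiles()` expression to derive a key from the lockfile contents. As a rough local analogue (the helper below is illustrative, not part of the workflow), the idea is a single digest over the listed files' contents:

```python
import hashlib


def pnpm_cache_key(runner_os: str, *lockfile_contents: bytes) -> str:
    # Rough analogue of "${{ runner.os }}-pnpm-${{ hashFiles(...) }}":
    # one SHA-256 digest over the concatenated file contents, so any
    # lockfile change produces a new cache key.
    digest = hashlib.sha256()
    for content in lockfile_contents:
        digest.update(content)
    return f"{runner_os}-pnpm-{digest.hexdigest()}"
```

Two runs with identical lockfiles hit the same cache entry; any edit to either file misses and rebuilds the store.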
.github/workflows/pr-overlap-check.yml (39 changes, vendored)

```diff
@@ -1,39 +0,0 @@
-name: PR Overlap Detection
-
-on:
-  pull_request:
-    types: [opened, synchronize, reopened]
-    branches:
-      - dev
-      - master
-
-permissions:
-  contents: read
-  pull-requests: write
-
-jobs:
-  check-overlaps:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0  # Need full history for merge testing
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-
-      - name: Configure git
-        run: |
-          git config user.email "github-actions[bot]@users.noreply.github.com"
-          git config user.name "github-actions[bot]"
-
-      - name: Run overlap detection
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        # Always succeed - this check informs contributors, it shouldn't block merging
-        continue-on-error: true
-        run: |
-          python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
```
.github/workflows/repo-workflow-checker.yml (2 changes, vendored)

```diff
@@ -11,7 +11,7 @@ jobs:
     steps:
       # - name: Wait some time for all actions to start
       #   run: sleep 30
-      - uses: actions/checkout@v6
+      - uses: actions/checkout@v4
      # with:
      #   fetch-depth: 0
       - name: Set up Python
```
```diff
@@ -1,195 +0,0 @@
-#!/usr/bin/env python3
-"""
-Add cache configuration to a resolved docker-compose file for all services
-that have a build key, and ensure image names match what docker compose expects.
-"""
-
-import argparse
-
-import yaml
-
-
-DEFAULT_BRANCH = "dev"
-CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Add cache config to a resolved compose file"
-    )
-    parser.add_argument(
-        "--source",
-        required=True,
-        help="Source compose file to read (should be output of `docker compose config`)",
-    )
-    parser.add_argument(
-        "--cache-from",
-        default="type=gha",
-        help="Cache source configuration",
-    )
-    parser.add_argument(
-        "--cache-to",
-        default="type=gha,mode=max",
-        help="Cache destination configuration",
-    )
-    for component in CACHE_BUILDS_FOR_COMPONENTS:
-        parser.add_argument(
-            f"--{component}-hash",
-            default="",
-            help=f"Hash for {component} cache scope (e.g., from hashFiles())",
-        )
-    parser.add_argument(
-        "--git-ref",
-        default="",
-        help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
-    )
-    args = parser.parse_args()
-
-    # Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
-    git_ref_scope = ""
-    if args.git_ref:
-        git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
-
-    with open(args.source, "r") as f:
-        compose = yaml.safe_load(f)
-
-    # Get project name from compose file or default
-    project_name = compose.get("name", "autogpt_platform")
-
-    def get_image_name(dockerfile: str, target: str) -> str:
-        """Generate image name based on Dockerfile folder and build target."""
-        dockerfile_parts = dockerfile.replace("\\", "/").split("/")
-        if len(dockerfile_parts) >= 2:
-            folder_name = dockerfile_parts[-2]  # e.g., "backend" or "frontend"
-        else:
-            folder_name = "app"
-        return f"{project_name}-{folder_name}:{target}"
-
-    def get_build_key(dockerfile: str, target: str) -> str:
-        """Generate a unique key for a Dockerfile+target combination."""
-        return f"{dockerfile}:{target}"
-
-    def get_component(dockerfile: str) -> str | None:
-        """Get component name (frontend/backend) from dockerfile path."""
-        for component in CACHE_BUILDS_FOR_COMPONENTS:
-            if component in dockerfile:
-                return component
-        return None
-
-    # First pass: collect all services with build configs and identify duplicates
-    # Track which (dockerfile, target) combinations we've seen
-    build_key_to_first_service: dict[str, str] = {}
-    services_to_build: list[str] = []
-    services_to_dedupe: list[str] = []
-
-    for service_name, service_config in compose.get("services", {}).items():
-        if "build" not in service_config:
-            continue
-
-        build_config = service_config["build"]
-        dockerfile = build_config.get("dockerfile", "Dockerfile")
-        target = build_config.get("target", "default")
-        build_key = get_build_key(dockerfile, target)
-
-        if build_key not in build_key_to_first_service:
-            # First service with this build config - it will do the actual build
-            build_key_to_first_service[build_key] = service_name
-            services_to_build.append(service_name)
-        else:
-            # Duplicate - will just use the image from the first service
-            services_to_dedupe.append(service_name)
-
-    # Second pass: configure builds and deduplicate
-    modified_services = []
-    for service_name, service_config in compose.get("services", {}).items():
-        if "build" not in service_config:
-            continue
-
-        build_config = service_config["build"]
-        dockerfile = build_config.get("dockerfile", "Dockerfile")
-        target = build_config.get("target", "latest")
-        image_name = get_image_name(dockerfile, target)
-
-        # Set image name for all services (needed for both builders and deduped)
-        service_config["image"] = image_name
-
-        if service_name in services_to_dedupe:
-            # Remove build config - this service will use the pre-built image
-            del service_config["build"]
-            continue
-
-        # This service will do the actual build - add cache config
-        cache_from_list = []
-        cache_to_list = []
-
-        component = get_component(dockerfile)
-        if not component:
-            # Skip services that don't clearly match frontend/backend
-            continue
-
-        # Get the hash for this component
-        component_hash = getattr(args, f"{component}_hash")
-
-        # Scope format: platform-{component}-{target}-{hash|ref}
-        # Example: platform-backend-server-abc123
-
-        if "type=gha" in args.cache_from:
-            # 1. Primary: exact hash match (most specific)
-            if component_hash:
-                hash_scope = f"platform-{component}-{target}-{component_hash}"
-                cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
-
-            # 2. Fallback: branch-based cache
-            if git_ref_scope:
-                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
-                cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
-
-            # 3. Fallback: dev branch cache (for PRs/feature branches)
-            if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
-                master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
-                cache_from_list.append(f"{args.cache_from},scope={master_scope}")
-
-        if "type=gha" in args.cache_to:
-            # Write to both hash-based and branch-based scopes
-            if component_hash:
-                hash_scope = f"platform-{component}-{target}-{component_hash}"
-                cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
-
-            if git_ref_scope:
-                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
-                cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
-
-        # Ensure we have at least one cache source/target
-        if not cache_from_list:
-            cache_from_list.append(args.cache_from)
-        if not cache_to_list:
-            cache_to_list.append(args.cache_to)
-
-        build_config["cache_from"] = cache_from_list
-        build_config["cache_to"] = cache_to_list
-        modified_services.append(service_name)
-
-    # Write back to the same file
-    with open(args.source, "w") as f:
-        yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
-
-    print(f"Added cache config to {len(modified_services)} services in {args.source}:")
-    for svc in modified_services:
-        svc_config = compose["services"][svc]
-        build_cfg = svc_config.get("build", {})
-        cache_from_list = build_cfg.get("cache_from", ["none"])
-        cache_to_list = build_cfg.get("cache_to", ["none"])
-        print(f"  - {svc}")
-        print(f"    image: {svc_config.get('image', 'N/A')}")
-        print(f"    cache_from: {cache_from_list}")
-        print(f"    cache_to: {cache_to_list}")
-    if services_to_dedupe:
-        print(
-            f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
-        )
-        for svc in services_to_dedupe:
-            print(f"  - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
-
-
-if __name__ == "__main__":
-    main()
```
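The core of the deleted script is its scope fallback chain for `cache_from` (exact content hash, then the current branch, then the `dev` default branch). That logic can be sketched in isolation; `cache_from_scopes` below is a hypothetical standalone helper, not a function from the script:

```python
def cache_from_scopes(
    component: str,
    target: str,
    component_hash: str,
    git_ref_scope: str,
    default_branch: str = "dev",
) -> list[str]:
    # Most-specific cache source first: exact content hash, then the
    # current branch, then the shared default-branch cache as a fallback.
    scopes = []
    if component_hash:
        scopes.append(f"platform-{component}-{target}-{component_hash}")
    if git_ref_scope:
        scopes.append(f"platform-{component}-{target}-{git_ref_scope}")
    if git_ref_scope and git_ref_scope != default_branch:
        scopes.append(f"platform-{component}-{target}-{default_branch}")
    return [f"type=gha,scope={s}" for s in scopes]
```

On a feature branch this yields three `cache_from` entries, so a cold hash miss can still reuse layers from the branch or `dev` caches.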
.gitignore (3 changes, vendored)

```diff
@@ -180,6 +180,3 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
-.next
-# Implementation plans (generated by AI agents)
-plans/
```
```diff
@@ -1,10 +1,3 @@
-default_install_hook_types:
-  - pre-commit
-  - pre-push
-  - post-checkout
-
-default_stages: [pre-commit]
-
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
@@ -24,7 +17,6 @@ repos:
       name: Detect secrets
       description: Detects high entropy strings that are likely to be passwords.
       files: ^autogpt_platform/
-      exclude: pnpm-lock\.yaml$
       stages: [pre-push]

   - repo: local
@@ -34,106 +26,49 @@ repos:
     - id: poetry-install
       name: Check & Install dependencies - AutoGPT Platform - Backend
       alias: poetry-install-platform-backend
+      entry: poetry -C autogpt_platform/backend install
       # include autogpt_libs source (since it's a path dependency)
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
-          poetry -C autogpt_platform/backend install
-        '
-      always_run: true
+      files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]

     - id: poetry-install
       name: Check & Install dependencies - AutoGPT Platform - Libs
       alias: poetry-install-platform-libs
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
-          poetry -C autogpt_platform/autogpt_libs install
-        '
-      always_run: true
+      entry: poetry -C autogpt_platform/autogpt_libs install
+      files: ^autogpt_platform/autogpt_libs/poetry\.lock$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]

-    - id: pnpm-install
-      name: Check & Install dependencies - AutoGPT Platform - Frontend
-      alias: pnpm-install-platform-frontend
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
-          pnpm --prefix autogpt_platform/frontend install
-        '
-      always_run: true
-      language: system
-      pass_filenames: false
-      stages: [pre-commit, post-checkout]

     - id: poetry-install
       name: Check & Install dependencies - Classic - AutoGPT
       alias: poetry-install-classic-autogpt
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
-          poetry -C classic/original_autogpt install
-        '
+      entry: poetry -C classic/original_autogpt install
       # include forge source (since it's a path dependency)
-      always_run: true
+      files: ^classic/(original_autogpt|forge)/poetry\.lock$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]

     - id: poetry-install
       name: Check & Install dependencies - Classic - Forge
       alias: poetry-install-classic-forge
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
-          poetry -C classic/forge install
-        '
-      always_run: true
+      entry: poetry -C classic/forge install
+      files: ^classic/forge/poetry\.lock$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]

     - id: poetry-install
       name: Check & Install dependencies - Classic - Benchmark
       alias: poetry-install-classic-benchmark
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
-          poetry -C classic/benchmark install
-        '
-      always_run: true
+      entry: poetry -C classic/benchmark install
+      files: ^classic/benchmark/poetry\.lock$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]

   - repo: local
     # For proper type checking, Prisma client must be up-to-date.
@@ -141,54 +76,12 @@ repos:
     - id: prisma-generate
       name: Prisma Generate - AutoGPT Platform - Backend
      alias: prisma-generate-platform-backend
-      entry: >
-        bash -c '
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-          else
-            git diff --cached --name-only
-          fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
-          cd autogpt_platform/backend
-          && poetry run prisma generate
-          && poetry run gen-prisma-stub
-        '
+      entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
       # include everything that triggers poetry install + the prisma schema
-      always_run: true
+      files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
+      types: [file]
       language: system
       pass_filenames: false
-      stages: [pre-commit, post-checkout]
-
-    - id: export-api-schema
-      name: Export API schema - AutoGPT Platform - Backend -> Frontend
-      alias: export-api-schema-platform
-      entry: >
-        bash -c '
-          cd autogpt_platform/backend
-          && poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
-          && cd ../frontend
-          && pnpm prettier --write ./src/app/api/openapi.json
-        '
-      files: ^autogpt_platform/backend/
-      language: system
-      pass_filenames: false
-
-    - id: generate-api-client
-      name: Generate API client - AutoGPT Platform - Frontend
-      alias: generate-api-client-platform-frontend
-      entry: >
-        bash -c '
-          SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
-          if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-            git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
-          else
-            git diff --quiet HEAD -- "$SCHEMA" && exit 0
-          fi;
-          cd autogpt_platform/frontend && pnpm generate:api
-        '
-      always_run: true
-      language: system
-      pass_filenames: false
-      stages: [pre-commit, post-checkout]

   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.7.2
```
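The rewritten hooks above drop the hand-rolled `git diff | grep` filtering and let pre-commit's own `files:` option decide when a hook runs; pre-commit applies the regex with a search against each changed path. A quick sketch of that matching (the `hook_should_run` helper is illustrative, not pre-commit's API):

```python
import re

# The `files:` pattern from the backend poetry-install hook above.
PATTERN = re.compile(r"^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$")


def hook_should_run(changed_paths: list[str]) -> bool:
    # pre-commit runs a hook when at least one changed file
    # matches the hook's `files:` regex.
    return any(PATTERN.search(p) for p in changed_paths)
```

This is why `always_run: true` could be removed: the filtering the bash wrapper did by hand is now declarative.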
autogpt_platform/.gitignore (3 changes, vendored)

```diff
@@ -1,3 +1,2 @@
 *.ignore.*
 *.ign.*
-.application.logs
```
```diff
@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
 - Backend/Frontend services use YAML anchors for consistent configuration
 - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

-### Branching Strategy
-
-- **`dev`** is the main development branch. All PRs should target `dev`.
-- **`master`** is the production branch. Only used for production releases.
-
 ### Creating Pull Requests

 - Create the PR against the `dev` branch of the repository.
@@ -60,12 +55,9 @@ AutoGPT Platform is a monorepo containing:

 ### Reviewing/Revising Pull Requests

-Use `/pr-review` to review a PR or `/pr-address` to address comments.
-When fetching comments manually:
-- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
-- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
-- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
+- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
+- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
+- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments

 ### Conventional Commits
```
```diff
@@ -1,40 +0,0 @@
--- =============================================================
--- View: analytics.auth_activities
--- Looker source alias: ds49 | Charts: 1
--- =============================================================
--- DESCRIPTION
--- Tracks authentication events (login, logout, SSO, password
--- reset, etc.) from Supabase's internal audit log.
--- Useful for monitoring sign-in patterns and detecting anomalies.
---
--- SOURCE TABLES
---   auth.audit_log_entries — Supabase internal auth event log
---
--- OUTPUT COLUMNS
---   created_at     TIMESTAMPTZ  When the auth event occurred
---   actor_id       TEXT         User ID who triggered the event
---   actor_via_sso  TEXT         Whether the action was via SSO ('true'/'false')
---   action         TEXT         Event type (e.g. 'login', 'logout', 'token_refreshed')
---
--- WINDOW
---   Rolling 90 days from current date
---
--- EXAMPLE QUERIES
---   -- Daily login counts
---   SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
---   FROM analytics.auth_activities
---   WHERE action = 'login'
---   GROUP BY 1 ORDER BY 1;
---
---   -- SSO vs password login breakdown
---   SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
---   WHERE action = 'login' GROUP BY 1;
--- =============================================================
-
-SELECT
-  created_at,
-  payload->>'actor_id' AS actor_id,
-  payload->>'actor_via_sso' AS actor_via_sso,
-  payload->>'action' AS action
-FROM auth.audit_log_entries
-WHERE created_at >= NOW() - INTERVAL '90 days'
```
```diff
@@ -1,105 +0,0 @@
--- =============================================================
--- View: analytics.graph_execution
--- Looker source alias: ds16 | Charts: 21
--- =============================================================
--- DESCRIPTION
--- One row per agent graph execution (last 90 days).
--- Unpacks the JSONB stats column into individual numeric columns
--- and normalises the executionStatus — runs that failed due to
--- insufficient credits are reclassified as 'NO_CREDITS' for
--- easier filtering. Error messages are scrubbed of IDs and URLs
--- to allow safe grouping.
---
--- SOURCE TABLES
---   platform.AgentGraphExecution — Execution records
---   platform.AgentGraph          — Agent graph metadata (for name)
---   platform.LibraryAgent        — To flag possibly-AI (safe-mode) agents
---
--- OUTPUT COLUMNS
---   id                   TEXT         Execution UUID
---   agentGraphId         TEXT         Agent graph UUID
---   agentGraphVersion    INT          Graph version number
---   executionStatus      TEXT         COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
---   createdAt            TIMESTAMPTZ  When the execution was queued
---   updatedAt            TIMESTAMPTZ  Last status update time
---   userId               TEXT         Owner user UUID
---   agentGraphName       TEXT         Human-readable agent name
---   cputime              DECIMAL      Total CPU seconds consumed
---   walltime             DECIMAL      Total wall-clock seconds
---   node_count           DECIMAL      Number of nodes in the graph
---   nodes_cputime        DECIMAL      CPU time across all nodes
---   nodes_walltime       DECIMAL      Wall time across all nodes
---   execution_cost       DECIMAL      Credit cost of this execution
---   correctness_score    FLOAT        AI correctness score (if available)
---   possibly_ai          BOOLEAN      True if agent has sensitive_action_safe_mode enabled
---   groupedErrorMessage  TEXT         Scrubbed error string (IDs/URLs replaced with wildcards)
---
--- WINDOW
---   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
---
--- EXAMPLE QUERIES
---   -- Daily execution counts by status
---   SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
---   FROM analytics.graph_execution
---   GROUP BY 1, 2 ORDER BY 1;
---
---   -- Average cost per execution by agent
---   SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
---   FROM analytics.graph_execution
---   WHERE "executionStatus" = 'COMPLETED'
---   GROUP BY 1 ORDER BY avg_cost DESC;
---
---   -- Top error messages
---   SELECT "groupedErrorMessage", COUNT(*) AS occurrences
---   FROM analytics.graph_execution
---   WHERE "executionStatus" = 'FAILED'
---   GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
--- =============================================================
-
-SELECT
-  ge."id" AS id,
-  ge."agentGraphId" AS agentGraphId,
-  ge."agentGraphVersion" AS agentGraphVersion,
-  CASE
-    WHEN jsonb_exists(ge."stats"::jsonb, 'error')
-      AND (
-        (ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
-        OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
-      )
-    THEN 'NO_CREDITS'
-    ELSE CAST(ge."executionStatus" AS TEXT)
-  END AS executionStatus,
-  ge."createdAt" AS createdAt,
-  ge."updatedAt" AS updatedAt,
-  ge."userId" AS userId,
-  g."name" AS agentGraphName,
-  (ge."stats"::jsonb->>'cputime')::decimal AS cputime,
-  (ge."stats"::jsonb->>'walltime')::decimal AS walltime,
-  (ge."stats"::jsonb->>'node_count')::decimal AS node_count,
-  (ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
-  (ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
-  (ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
-  (ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
-  COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
-  REGEXP_REPLACE(
-    REGEXP_REPLACE(
-      TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
-      '(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
-      '\1\2/...', 'gi'
-    ),
-    '[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
-  ) AS groupedErrorMessage
-FROM platform."AgentGraphExecution" ge
-LEFT JOIN platform."AgentGraph" g
-  ON ge."agentGraphId" = g."id"
-  AND ge."agentGraphVersion" = g."version"
-LEFT JOIN (
-  SELECT DISTINCT ON ("userId", "agentGraphId")
-    "userId", "agentGraphId",
-    ("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
-  FROM platform."LibraryAgent"
-  WHERE "isDeleted" = FALSE
-    AND "isArchived" = FALSE
-  ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
-) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
-WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
```
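The `groupedErrorMessage` scrubbing above (strip quotes, collapse URLs to their host, replace any ID-like token containing a digit with `*`) translates directly into Python regexes, using the same patterns as the view. The `scrub_error` helper is an illustrative port, not code from the repo:

```python
import re


def scrub_error(msg: str) -> str:
    # Same steps as the view's nested REGEXP_REPLACE calls.
    msg = msg.strip('"')
    # Collapse URLs to scheme + host, dropping port and path.
    msg = re.sub(
        r"(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?",
        r"\1\2/...",
        msg,
        flags=re.IGNORECASE,
    )
    # Wildcard any token containing a digit (UUIDs, counts, ports, ...).
    msg = re.sub(r"[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*", "*", msg)
    return msg
```

After scrubbing, errors that differ only by an ID or URL collapse into one group, which is what makes the "Top error messages" example query meaningful.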
@@ -1,101 +0,0 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
--   One row per node (block) execution (last 90 days).
--   Unpacks stats JSONB and joins to identify which block type
--   was run. For failed nodes, joins the error output and
--   scrubs it for safe grouping.
--
-- SOURCE TABLES
--   platform.AgentNodeExecution            — Node execution records
--   platform.AgentNode                     — Node → block mapping
--   platform.AgentBlock                    — Block name/ID
--   platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
--   id                    TEXT         Node execution UUID
--   agentGraphExecutionId TEXT         Parent graph execution UUID
--   agentNodeId           TEXT         Node UUID within the graph
--   executionStatus       TEXT         COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
--   addedTime             TIMESTAMPTZ  When the node was queued
--   queuedTime            TIMESTAMPTZ  When it entered the queue
--   startedTime           TIMESTAMPTZ  When execution started
--   endedTime             TIMESTAMPTZ  When execution finished
--   inputSize             BIGINT       Input payload size in bytes
--   outputSize            BIGINT       Output payload size in bytes
--   walltime              NUMERIC      Wall-clock seconds for this node
--   cputime               NUMERIC      CPU seconds for this node
--   llmRetryCount         INT          Number of LLM retries
--   llmCallCount          INT          Number of LLM API calls made
--   inputTokenCount       BIGINT       LLM input tokens consumed
--   outputTokenCount      BIGINT       LLM output tokens produced
--   blockName             TEXT         Human-readable block name (e.g. 'OpenAIBlock')
--   blockId               TEXT         Block UUID
--   groupedErrorMessage   TEXT         Scrubbed error (IDs/URLs wildcarded)
--   errorMessage          TEXT         Raw error output (only set when FAILED)
--
-- WINDOW
--   Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
--   -- Most-used blocks by execution count
--   SELECT "blockName", COUNT(*) AS executions,
--          COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
--   FROM analytics.node_block_execution
--   GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
--   -- Average LLM token usage per block
--   SELECT "blockName",
--          AVG("inputTokenCount") AS avg_input_tokens,
--          AVG("outputTokenCount") AS avg_output_tokens
--   FROM analytics.node_block_execution
--   WHERE "llmCallCount" > 0
--   GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
--   -- Top failure reasons
--   SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
--   FROM analytics.node_block_execution
--   WHERE "executionStatus" = 'FAILED'
--   GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================

SELECT
  ne."id" AS id,
  ne."agentGraphExecutionId" AS agentGraphExecutionId,
  ne."agentNodeId" AS agentNodeId,
  CAST(ne."executionStatus" AS TEXT) AS executionStatus,
  ne."addedTime" AS addedTime,
  ne."queuedTime" AS queuedTime,
  ne."startedTime" AS startedTime,
  ne."endedTime" AS endedTime,
  (ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
  (ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
  (ne."stats"::jsonb->>'walltime')::numeric AS walltime,
  (ne."stats"::jsonb->>'cputime')::numeric AS cputime,
  (ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
  (ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
  (ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
  (ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
  b."name" AS blockName,
  b."id" AS blockId,
  REGEXP_REPLACE(
    REGEXP_REPLACE(
      TRIM(BOTH '"' FROM eio."data"::text),
      '(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
      '\1\2/...', 'gi'
    ),
    '[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
  ) AS groupedErrorMessage,
  eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
  ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
  ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
  ON eio."referencedByOutputExecId" = ne."id"
  AND eio."name" = 'error'
  AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
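The two nested `REGEXP_REPLACE` calls behind `groupedErrorMessage` can be mirrored outside the database for testing or ad-hoc log grouping. A minimal Python sketch (not part of the view; Python's `re` and Postgres regexes differ slightly, but these patterns behave the same here):

```python
import re

def group_error_message(raw: str) -> str:
    """Mirror the view's scrubbing: strip surrounding quotes, collapse
    URLs to scheme://host/..., then wildcard any token containing a
    digit (IDs, UUIDs, ports) so distinct errors group together."""
    text = raw.strip('"')
    # First pass: keep scheme + host, drop port and path (as '\1\2/...')
    text = re.sub(r'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
                  r'\1\2/...', text, flags=re.IGNORECASE)
    # Second pass: replace digit-bearing identifier-like tokens with '*'
    text = re.sub(r'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', text)
    return text
```

For example, `'"Timeout calling https://api.example.com:8080/v1/run for node 42"'` groups to `'Timeout calling https://api.example.com/... for node *'`.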
@@ -1,97 +0,0 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
--   Weekly cohort retention broken down per individual agent.
--   Cohort = week of a user's first use of THAT specific agent.
--   Tells you which agents keep users coming back vs. one-shot
--   use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
--   platform.AgentGraphExecution — Execution records (user × agent × time)
--   platform.AgentGraph          — Agent names
--
-- OUTPUT COLUMNS
--   agent_id           TEXT    Agent graph UUID
--   agent_label        TEXT    'AgentName [first8chars]'
--   agent_label_n      TEXT    'AgentName [first8chars] (n=total_users)'
--   cohort_week_start  DATE    Week users first ran this agent
--   cohort_label       TEXT    ISO week label
--   cohort_label_n     TEXT    ISO week label with cohort size
--   user_lifetime_week INT     Weeks since first use of this agent
--   cohort_users       BIGINT  Users in this cohort for this agent
--   active_users       BIGINT  Users who ran the agent again in week k
--   retention_rate     FLOAT   active_users / cohort_users
--   cohort_users_w0    BIGINT  cohort_users only at week 0 (safe to SUM)
--   agent_total_users  BIGINT  Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
--   -- Best-retained agents at week 2
--   SELECT agent_label, AVG(retention_rate) AS w2_retention
--   FROM analytics.retention_agent
--   WHERE user_lifetime_week = 2 AND cohort_users >= 10
--   GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
--   -- Agents with most unique users
--   SELECT DISTINCT agent_label, agent_total_users
--   FROM analytics.retention_agent
--   ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
         e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('week', e."createdAt")::date AS week_start
  FROM platform."AgentGraphExecution" e
),
first_use AS (
  SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1,2
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
  SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
  SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
  SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
  g.agent_id,
  COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
  COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(ac.active_users,0) AS active_users,
  COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
  COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
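All the weekly retention views share the same cohort arithmetic: `DATE_TRUNC('week', …)` snaps a timestamp to the Monday of its ISO week, and `user_lifetime_week` is the difference between two week starts divided by 7. A small Python sketch of that arithmetic (illustrative only, not part of any view):

```python
from datetime import date, timedelta

def week_start(d: date) -> date:
    """Monday of d's ISO week — the analogue of DATE_TRUNC('week', d)."""
    return d - timedelta(days=d.weekday())

def lifetime_week(first_use: date, activity: date) -> int:
    """Whole weeks elapsed between the cohort week and an activity week."""
    return (week_start(activity) - week_start(first_use)).days // 7
```

A first use on Wednesday 2025-03-05 anchors the cohort to Monday 2025-03-03, and activity on 2025-03-17 lands in lifetime week 2.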
@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
--   Daily cohort retention based on agent executions.
--   Cohort anchor = day of user's FIRST ever execution.
--   Only includes cohorts from the last 90 days, up to day 30.
--   Great for early engagement analysis (did users run another
--   agent the next day?).
--
-- SOURCE TABLES
--   platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
--   Same pattern as retention_login_daily.
--   cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
--   -- Day-3 execution retention
--   SELECT cohort_label, retention_rate_bounded AS d3_retention
--   FROM analytics.retention_execution_daily
--   WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================

WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('day', e."createdAt")::date AS day_start
  FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
  SELECT user_id, MIN(created_at) AS first_exec_at,
         DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
  SELECT ad.user_id, fe.cohort_day_start,
         (ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
  FROM activity_days ad JOIN first_exec fe USING (user_id)
  WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
  SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_day_start, cs.cohort_users,
         LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
  g.cohort_day_start,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_day, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;
@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
--   Weekly cohort retention based on agent executions.
--   Cohort anchor = week of user's FIRST ever agent execution
--   (not first login). Only includes cohorts from the last 180 days.
--   Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
--   platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
--   Same pattern as retention_login_weekly.
--   cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
--   -- Week-2 execution retention
--   SELECT cohort_label, retention_rate_bounded
--   FROM analytics.retention_execution_weekly
--   WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('week', e."createdAt")::date AS week_start
  FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
  SELECT user_id, MIN(created_at) AS first_exec_at,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fe.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_exec fe USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
@@ -1,94 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
--   Daily cohort retention based on login sessions.
--   Same logic as retention_login_weekly but at day granularity,
--   showing up to day 30 for cohorts from the last 90 days.
--   Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
--   auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
--   cohort_day_start          DATE    First day the cohort logged in
--   cohort_label              TEXT    Date string (e.g. '2025-03-01')
--   cohort_label_n            TEXT    Date + cohort size (e.g. '2025-03-01 (n=12)')
--   user_lifetime_day         INT     Days since first login (0 = signup day)
--   cohort_users              BIGINT  Total users in cohort
--   active_users_bounded      BIGINT  Users active on exactly day k
--   retained_users_unbounded  BIGINT  Users active any time on/after day k
--   retention_rate_bounded    FLOAT   bounded / cohort_users
--   retention_rate_unbounded  FLOAT   unbounded / cohort_users
--   cohort_users_d0           BIGINT  cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
--   -- Day-1 retention rate (came back next day)
--   SELECT cohort_label, retention_rate_bounded AS d1_retention
--   FROM analytics.retention_login_daily
--   WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
--   -- Average retention curve across all cohorts
--   SELECT user_lifetime_day,
--          SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
--   FROM analytics.retention_login_daily
--   GROUP BY 1 ORDER BY 1;
-- =============================================================

WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('day', s.created_at)::date AS day_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
  SELECT ad.user_id, fl.cohort_day_start,
         (ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
  FROM activity_days ad JOIN first_login fl USING (user_id)
  WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_day_start, cs.cohort_users,
         LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
  g.cohort_day_start,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_day, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;
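The retention views distinguish bounded retention (active on exactly day/week k) from unbounded retention (active on k or any later period). A toy Python sketch of those two semantics, for one cohort, with each user's set of active lifetime days (illustrative only; the views compute this in SQL over real session data):

```python
def retention_rates(active_days: dict, k: int):
    """Given user -> set of lifetime days active (day 0 = signup),
    return (bounded, unbounded) day-k retention for the cohort.
    Bounded: fraction active on exactly day k.
    Unbounded: fraction active on day k or any later day."""
    cohort = len(active_days)
    bounded = sum(1 for days in active_days.values() if k in days)
    unbounded = sum(1 for days in active_days.values() if any(d >= k for d in days))
    return bounded / cohort, unbounded / cohort
```

With users `{'a': {0, 1, 3}, 'b': {0}, 'c': {0, 5}}` at k=1, only `a` is active on day 1 (bounded 1/3), but both `a` and `c` are active on or after day 1 (unbounded 2/3) — unbounded retention is always ≥ bounded.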
@@ -1,96 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
--   Weekly cohort retention from login sessions, restricted to
--   users who "onboarded" — defined as running at least one
--   agent within 365 days of their first login.
--   Filters out users who signed up but never activated,
--   giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
--   auth.sessions                — Login session records
--   platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
--   Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
--   retention_rate_bounded, retention_rate_unbounded, etc.)
--   Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
--   -- Compare week-4 retention: all users vs onboarded only
--   SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
--   FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
--   UNION ALL
--   SELECT 'onboarded', AVG(retention_rate_bounded)
--   FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('week', s.created_at)::date AS week_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
),
onboarders AS (
  SELECT fl.user_id FROM first_login_all fl
  WHERE EXISTS (
    SELECT 1 FROM platform."AgentGraphExecution" e
    WHERE e."userId"::text = fl.user_id
      AND e."createdAt" >= fl.first_login_time
      AND e."createdAt" < fl.first_login_time
          + make_interval(days => (SELECT onboarding_window_days FROM params))
  )
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fl.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_login fl USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
@@ -1,103 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
--   Weekly cohort retention based on login sessions.
--   Users are grouped by the ISO week of their first ever login.
--   For each cohort × lifetime-week combination, outputs both:
--     - bounded rate: % active in exactly that week
--     - unbounded rate: % who were ever active on or after that week
--   Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
--   auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
--   cohort_week_start        The Monday of the week users first logged in
--   user_lifetime_week       0 = signup week, 1 = one week later, etc.
--   retention_rate_bounded   = active_users_bounded / cohort_users
--   retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
--   cohort_week_start         DATE    First day of the cohort's signup week
--   cohort_label              TEXT    ISO week label (e.g. '2025-W01')
--   cohort_label_n            TEXT    ISO week label with cohort size (e.g. '2025-W01 (n=42)')
--   user_lifetime_week        INT     Weeks since first login (0 = signup week)
--   cohort_users              BIGINT  Total users in this cohort (denominator)
--   active_users_bounded      BIGINT  Users active in exactly week k
--   retained_users_unbounded  BIGINT  Users active any time on/after week k
--   retention_rate_bounded    FLOAT   bounded active / cohort_users
--   retention_rate_unbounded  FLOAT   unbounded retained / cohort_users
--   cohort_users_w0           BIGINT  cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
--   -- Week-1 retention rate per cohort
--   SELECT cohort_label, retention_rate_bounded AS w1_retention
--   FROM analytics.retention_login_weekly
--   WHERE user_lifetime_week = 1
--   ORDER BY cohort_week_start;
--
--   -- Overall average retention curve (all cohorts combined)
--   SELECT user_lifetime_week,
--          SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
--   FROM analytics.retention_login_weekly
--   GROUP BY 1 ORDER BY 1;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('week', s.created_at)::date AS week_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fl.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_login fl USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
  FROM last_active la
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
|
||||||
GROUP BY 1,2
|
|
||||||
),
|
|
||||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
|
||||||
cohort_caps AS (
|
|
||||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
|
||||||
LEAST((SELECT max_weeks FROM params),
|
|
||||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
|
||||||
FROM cohort_sizes cs
|
|
||||||
),
|
|
||||||
grid AS (
|
|
||||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
|
||||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
g.cohort_week_start,
|
|
||||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
|
||||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
|
||||||
g.user_lifetime_week, g.cohort_users,
|
|
||||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
|
||||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
|
||||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
|
||||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
|
||||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
|
||||||
FROM grid g
|
|
||||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
|
||||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
|
||||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
|
||||||
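As a cross-check on the bounded vs. unbounded semantics above, here is a minimal Python sketch (the helper names `week_start` and `retention` are hypothetical, not part of the view) that reproduces the same per-cohort counts from raw login dates. Unlike the view, it does not zero-fill a full cohort × week grid; it only emits weeks up to each user's last active week.

```python
from datetime import date, timedelta
from collections import defaultdict

def week_start(d: date) -> date:
    """Monday of the week containing d (mirrors DATE_TRUNC('week', ...))."""
    return d - timedelta(days=d.weekday())

def retention(logins: dict[str, list[date]], max_weeks: int = 12):
    """Return {(cohort_week, lifetime_week): (bounded, unbounded, cohort_size)}."""
    first = {u: week_start(min(ds)) for u, ds in logins.items()}
    cohort_size = defaultdict(int)
    for cw in first.values():
        cohort_size[cw] += 1
    # distinct active weeks per user, expressed as lifetime-week offsets
    active = {u: {(week_start(d) - first[u]).days // 7 for d in ds}
              for u, ds in logins.items()}
    out = {}
    for u, weeks in active.items():
        cw, last = first[u], max(weeks)
        for k in range(0, min(last, max_weeks) + 1):
            b, ub, n = out.get((cw, k), (0, 0, cohort_size[cw]))
            # bounded: active in exactly week k; unbounded: active on/after week k
            out[(cw, k)] = (b + (k in weeks), ub + 1, n)
    return out
```

For a two-user cohort where one user logs in again two weeks later, week 1 has a bounded count of 0 but an unbounded count of 1, matching the view's distinction.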
@@ -1,71 +0,0 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
--   platform.CreditTransaction — Credit debit/credit records
--   platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
--   transactionKey         TEXT        Unique transaction identifier
--   userId                 TEXT        User who was charged
--   amount                 DECIMAL     Credit amount (positive = credit, negative = debit)
--   negativeAmount         DECIMAL     amount * -1 (convenience for spend charts)
--   transactionType        TEXT        Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
--   transactionTime        TIMESTAMPTZ When the transaction was recorded
--   blockId                TEXT        Block UUID that triggered the spend
--   blockName              TEXT        Human-readable block name
--   llm_provider           TEXT        LLM provider (e.g. 'openai', 'anthropic')
--   llm_model              TEXT        Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
--   node_exec_id           TEXT        Linked node execution UUID
--   llm_call_count         INT         LLM API calls made in that execution
--   llm_retry_count        INT         LLM retries in that execution
--   llm_input_token_count  INT         Input tokens consumed
--   llm_output_token_count INT         Output tokens produced
--
-- WINDOW
--   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
--        SUM("negativeAmount") AS total_cost,
--        SUM("llm_input_token_count") AS input_tokens,
--        SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================

SELECT
    -- camelCase aliases must be quoted, otherwise Postgres folds them to
    -- lowercase and the quoted references in the example queries would fail
    c."transactionKey" AS "transactionKey",
    c."userId" AS "userId",
    c."amount" AS amount,
    c."amount" * -1 AS "negativeAmount",
    c."type" AS "transactionType",
    c."createdAt" AS "transactionTime",
    c.metadata->>'block_id' AS "blockId",
    c.metadata->>'block' AS "blockName",
    c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
    c.metadata->'input'->>'model' AS llm_model,
    c.metadata->>'node_exec_id' AS node_exec_id,
    (ne."stats"->>'llm_call_count')::int AS llm_call_count,
    (ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
    (ne."stats"->>'input_token_count')::int AS llm_input_token_count,
    (ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
    ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
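The view pulls its LLM fields out of a JSON `metadata` blob with `->`/`->>` chains. The same extraction can be sketched in Python (the helper name `spend_fields` is hypothetical); like the SQL operators, missing keys simply yield `None` rather than raising:

```python
import json

def spend_fields(metadata_json: str) -> dict:
    """Extract the same fields the view reads from CreditTransaction.metadata.
    Missing keys become None, matching SQL's JSON operators."""
    m = json.loads(metadata_json or "{}")
    inp = m.get("input") or {}
    creds = inp.get("credentials") or {}
    return {
        "blockId": m.get("block_id"),          # metadata->>'block_id'
        "blockName": m.get("block"),           # metadata->>'block'
        "llm_provider": creds.get("provider"), # metadata->'input'->'credentials'->>'provider'
        "llm_model": inp.get("model"),         # metadata->'input'->>'model'
        "node_exec_id": m.get("node_exec_id"), # metadata->>'node_exec_id'
    }
```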
@@ -1,45 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
--   platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
--   id                            TEXT        Onboarding record UUID
--   createdAt                     TIMESTAMPTZ When onboarding started
--   updatedAt                     TIMESTAMPTZ Last update to onboarding state
--   usageReason                   TEXT        Why user signed up (e.g. 'work', 'personal')
--   integrations                  TEXT[]      Array of integration names the user selected
--   userId                        TEXT        User UUID
--   completedSteps                TEXT[]      Array of onboarding step enums completed
--   selectedStoreListingVersionId TEXT        First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================

SELECT
    id,
    "createdAt",
    "updatedAt",
    "usageReason",
    integrations,
    "userId",
    "completedSteps",
    "selectedStoreListingVersionId"
FROM platform."UserOnboarding"
@@ -1,100 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
--   platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
--   step            TEXT    Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
--   step_order      INT     Numeric position in the funnel (1=first, 22=last)
--   users_completed BIGINT  Distinct users who completed this step
--   pct_from_prev   NUMERIC % of users from the previous step who reached this one
--
-- STEP ORDER
--   1 WELCOME        9 MARKETPLACE_VISIT      17 SCHEDULE_AGENT
--   2 USAGE_REASON  10 MARKETPLACE_ADD_AGENT  18 RUN_AGENTS
--   3 INTEGRATIONS  11 MARKETPLACE_RUN_AGENT  19 RUN_3_DAYS
--   4 AGENT_CHOICE  12 BUILDER_OPEN           20 TRIGGER_WEBHOOK
--   5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT     21 RUN_14_DAYS
--   6 AGENT_INPUT   14 BUILDER_RUN_AGENT      22 RUN_AGENTS_100
--   7 CONGRATS      15 VISIT_COPILOT
--   8 GET_RESULTS   16 RE_RUN_AGENT
--
-- WINDOW
--   Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================

WITH all_steps AS (
    -- Complete ordered grid of all 22 steps so zero-completion steps
    -- are always present, keeping LAG comparisons correct.
    SELECT step_name, step_order
    FROM (VALUES
        ('WELCOME', 1),
        ('USAGE_REASON', 2),
        ('INTEGRATIONS', 3),
        ('AGENT_CHOICE', 4),
        ('AGENT_NEW_RUN', 5),
        ('AGENT_INPUT', 6),
        ('CONGRATS', 7),
        ('GET_RESULTS', 8),
        ('MARKETPLACE_VISIT', 9),
        ('MARKETPLACE_ADD_AGENT', 10),
        ('MARKETPLACE_RUN_AGENT', 11),
        ('BUILDER_OPEN', 12),
        ('BUILDER_SAVE_AGENT', 13),
        ('BUILDER_RUN_AGENT', 14),
        ('VISIT_COPILOT', 15),
        ('RE_RUN_AGENT', 16),
        ('SCHEDULE_AGENT', 17),
        ('RUN_AGENTS', 18),
        ('RUN_3_DAYS', 19),
        ('TRIGGER_WEBHOOK', 20),
        ('RUN_14_DAYS', 21),
        ('RUN_AGENTS_100', 22)
    ) AS t(step_name, step_order)
),
raw AS (
    SELECT
        u."userId",
        step_txt::text AS step
    FROM platform."UserOnboarding" u
    CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
    WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
    SELECT step, COUNT(DISTINCT "userId") AS users_completed
    FROM raw GROUP BY step
),
funnel AS (
    SELECT
        a.step_name AS step,
        a.step_order,
        COALESCE(sc.users_completed, 0) AS users_completed,
        ROUND(
            100.0 * COALESCE(sc.users_completed, 0)
            / NULLIF(
                LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
                0
            ),
            2
        ) AS pct_from_prev
    FROM all_steps a
    LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order
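The `pct_from_prev` logic above combines `LAG` with `NULLIF` so the first step, and any step whose predecessor had zero completions, yields NULL instead of a division error. A minimal Python sketch of the same rule (the helper name `funnel_pcts` is hypothetical):

```python
def funnel_pcts(users_completed: list[int]) -> list:
    """pct_from_prev per step: round(100 * step_k / step_{k-1}, 2).
    None when undefined (first step, or previous step had 0 completions),
    mirroring LAG(...) with NULLIF(..., 0)."""
    out, prev = [], None
    for n in users_completed:
        # `prev` is falsy both when None (no LAG row) and when 0 (NULLIF)
        out.append(round(100.0 * n / prev, 2) if prev else None)
        prev = n
    return out
```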
@@ -1,41 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
--   platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
--   integration            TEXT    Integration name (e.g. 'github', 'slack', 'notion')
--   users_with_integration BIGINT  Distinct users who selected this integration
--
-- WINDOW
--   Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================

WITH exploded AS (
    SELECT
        u."userId" AS user_id,
        UNNEST(u."integrations") AS integration
    FROM platform."UserOnboarding" u
    WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
    integration,
    COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC
@@ -1,145 +0,0 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
--   auth.sessions                — Login/session records
--   platform.AgentGraph          — Graphs (agents) built by the user
--   platform.AgentGraphExecution — Agent run history
--   platform.AgentNodeExecution  — Individual block execution history
--
-- PERFORMANCE NOTE
-- Each CTE aggregates its own table independently by userId.
-- This avoids the fan-out that occurs when driving every join
-- from user_logins across the two largest tables
-- (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
--   user_id                   TEXT        Supabase user UUID
--   first_login_time          TIMESTAMPTZ First ever session created_at
--   last_login_time           TIMESTAMPTZ Most recent session created_at
--   last_visit_time           TIMESTAMPTZ Max of last refresh or login
--   last_agent_save_time      TIMESTAMPTZ Last time user saved an agent graph
--   agent_count               BIGINT      Number of distinct active graphs built (0 if none)
--   first_agent_run_time      TIMESTAMPTZ First ever graph execution
--   last_agent_run_time       TIMESTAMPTZ Most recent graph execution
--   unique_agent_runs         BIGINT      Distinct agent graphs ever run (0 if none)
--   agent_runs                BIGINT      Total graph execution count (0 if none)
--   node_execution_count      BIGINT      Total node executions across all runs
--   node_execution_failed     BIGINT      Node executions with FAILED status
--   node_execution_completed  BIGINT      Node executions with COMPLETED status
--   node_execution_terminated BIGINT      Node executions with TERMINATED status
--   node_execution_queued     BIGINT      Node executions with QUEUED status
--   node_execution_running    BIGINT      Node executions with RUNNING status
--   is_active_after_7d        INT         1=returned after day 7, 0=did not, NULL=too early to tell
--   node_execution_incomplete BIGINT      Node executions with INCOMPLETE status
--   node_execution_review     BIGINT      Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
--   SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
--   / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
--   AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================

WITH user_logins AS (
    SELECT
        user_id::text AS user_id,
        MIN(created_at) AS first_login_time,
        MAX(created_at) AS last_login_time,
        GREATEST(
            MAX(refreshed_at)::timestamptz,
            MAX(created_at)::timestamptz
        ) AS last_visit_time
    FROM auth.sessions
    GROUP BY user_id
),
user_agents AS (
    -- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
    SELECT
        "userId"::text AS user_id,
        MAX("updatedAt") AS last_agent_save_time,
        COUNT(DISTINCT "id") AS agent_count
    FROM platform."AgentGraph"
    WHERE "isActive"
    GROUP BY "userId"
),
user_graph_runs AS (
    -- Aggregate AgentGraphExecution directly by userId
    SELECT
        "userId"::text AS user_id,
        MIN("createdAt") AS first_agent_run_time,
        MAX("createdAt") AS last_agent_run_time,
        COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
        COUNT("id") AS agent_runs
    FROM platform."AgentGraphExecution"
    GROUP BY "userId"
),
user_node_runs AS (
    -- Aggregate AgentNodeExecution directly; resolve userId via a
    -- single join to AgentGraphExecution instead of fanning out from
    -- user_logins through both large tables.
    SELECT
        g."userId"::text AS user_id,
        COUNT(*) AS node_execution_count,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
    FROM platform."AgentNodeExecution" n
    JOIN platform."AgentGraphExecution" g
        ON g."id" = n."agentGraphExecutionId"
    GROUP BY g."userId"
)
SELECT
    ul.user_id,
    ul.first_login_time,
    ul.last_login_time,
    ul.last_visit_time,
    ua.last_agent_save_time,
    COALESCE(ua.agent_count, 0) AS agent_count,
    gr.first_agent_run_time,
    gr.last_agent_run_time,
    COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
    COALESCE(gr.agent_runs, 0) AS agent_runs,
    COALESCE(nr.node_execution_count, 0) AS node_execution_count,
    COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
    COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
    COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
    COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
    COALESCE(nr.node_execution_running, 0) AS node_execution_running,
    CASE
        WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
             AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
        WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
             AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
        ELSE NULL
    END AS is_active_after_7d,
    COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
    COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
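The three-way `is_active_after_7d` flag condenses to a simple guard: accounts younger than 7 days are NULL (too early to judge), otherwise compare the last visit against day 7. A Python sketch of the same decision (the function name is hypothetical):

```python
from datetime import datetime, timedelta

def is_active_after_7d(first_login: datetime, last_visit: datetime, now: datetime):
    """1 = user returned on/after day 7, 0 = did not,
    None = account is younger than 7 days (ELSE NULL branch in the SQL)."""
    if first_login >= now - timedelta(days=7):
        return None
    return 1 if last_visit >= first_login + timedelta(days=7) else 0
```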
autogpt_platform/autogpt_libs/poetry.lock (generated): 1857 lines changed
File diff suppressed because it is too large
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 colorama = "^0.4.6"
-cryptography = "^46.0"
+cryptography = "^45.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.7"
+fastapi = "^0.116.1"
-google-cloud-logging = "^3.13.0"
+google-cloud-logging = "^3.12.1"
-launchdarkly-server-sdk = "^9.15.0"
+launchdarkly-server-sdk = "^9.12.0"
-pydantic = "^2.12.5"
+pydantic = "^2.11.7"
-pydantic-settings = "^2.12.0"
+pydantic-settings = "^2.10.1"
-pyjwt = { version = "^2.11.0", extras = ["crypto"] }
+pyjwt = { version = "^2.10.1", extras = ["crypto"] }
 redis = "^6.2.0"
-supabase = "^2.28.0"
+supabase = "^2.16.0"
-uvicorn = "^0.40.0"
+uvicorn = "^0.35.0"
 
 [tool.poetry.group.dev.dependencies]
-pyright = "^1.1.408"
+pyright = "^1.1.404"
 pytest = "^8.4.1"
-pytest-asyncio = "^1.3.0"
+pytest-asyncio = "^1.1.0"
-pytest-mock = "^3.15.1"
+pytest-mock = "^3.14.1"
-pytest-cov = "^7.0.0"
+pytest-cov = "^6.2.1"
-ruff = "^0.15.0"
+ruff = "^0.12.11"
 
 [build-system]
 requires = ["poetry-core"]
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
 ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
 UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
 
-## ===== SIGNUP / INVITE GATE ===== ##
-# Set to true to require an invite before users can sign up
-ENABLE_INVITE_GATE=false
-
 ## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
 # Platform URLs (set these for webhooks and OAuth to work)
 PLATFORM_BASE_URL=http://localhost:8000
@@ -108,12 +104,6 @@ TWITTER_CLIENT_SECRET=
 # Make a new workspace for your OAuth APP -- trust me
 # https://linear.app/settings/api/applications/new
 # Callback URL: http://localhost:3000/auth/integrations/oauth_callback
-LINEAR_API_KEY=
-# Linear project and team IDs for the feature request tracker.
-# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
-# and in team settings. Used by the chat copilot to file and search feature requests.
-LINEAR_FEATURE_REQUEST_PROJECT_ID=
-LINEAR_FEATURE_REQUEST_TEAM_ID=
 LINEAR_CLIENT_ID=
 LINEAR_CLIENT_SECRET=
 
@@ -162,7 +152,6 @@ REPLICATE_API_KEY=
 REVID_API_KEY=
 SCREENSHOTONE_API_KEY=
 UNREAL_SPEECH_API_KEY=
-ELEVENLABS_API_KEY=
 
 # Data & Search Services
 E2B_API_KEY=
@@ -194,8 +183,5 @@ ZEROBOUNCE_API_KEY=
 POSTHOG_API_KEY=
 POSTHOG_HOST=https://eu.i.posthog.com
-
-# Tally Form Integration (pre-populate business understanding on signup)
-TALLY_API_KEY=
 
 # Other Services
 AUTOMOD_API_KEY=
autogpt_platform/backend/.gitignore (vendored): 3 lines changed
@@ -19,6 +19,3 @@ load-tests/*.json
 load-tests/*.log
 load-tests/node_modules/*
 migrations/*/rollback*.sql
-
-# Workspace files
-workspaces/
@@ -58,31 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
 - **Authentication**: JWT-based with Supabase integration
 - **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
 
-## Code Style
-
-- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
-- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
-- **Pydantic models** over dataclass/namedtuple/dict for structured data
-- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
-- **List comprehensions** over manual loop-and-append
-- **Early return** — guard clauses first, avoid deep nesting
-- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
-- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
-- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
-- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
-- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
-- **`max(0, value)` guards** — for computed values that should never be negative
-- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
-- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
-- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
-
 ## Testing Approach
 
 - Uses pytest with snapshot testing for API responses
 - Test files are colocated with source files (`*_test.py`)
-- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
-- After refactoring, update mock targets to match new module paths
-- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
 
 ## Database Schema
 
@@ -1,5 +1,3 @@
-# ============================ DEPENDENCY BUILDER ============================ #
-
 FROM debian:13-slim AS builder
 
 # Set environment variables
@@ -39,13 +37,15 @@ ENV POETRY_VIRTUALENVS_CREATE=true
 ENV POETRY_VIRTUALENVS_IN_PROJECT=true
 ENV PATH=/opt/poetry/bin:$PATH
 
-RUN pip3 install poetry --break-system-packages
+RUN pip3 install --no-cache-dir poetry --break-system-packages
 
 # Copy and install dependencies
 COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
 COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
 WORKDIR /app/autogpt_platform/backend
-RUN poetry install --no-ansi --no-root
+# Production image only needs runtime deps; dev deps (pytest, black, ruff, etc.)
+# are installed locally via `poetry install --with dev` per the development docs
+RUN poetry install --no-ansi --no-root --only main
 
 # Generate Prisma client
 COPY autogpt_platform/backend/schema.prisma ./
@@ -53,106 +53,65 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub
 
-# =============================== DB MIGRATOR =============================== #
+# Clean up build artifacts and caches to reduce layer size
+# Note: setuptools is kept as it's a direct dependency (used by aioclamd via pkg_resources)
+RUN find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true; \
+    find /app -type d -name tests -exec rm -rf {} + 2>/dev/null || true; \
+    find /app -type d -name test -exec rm -rf {} + 2>/dev/null || true; \
+    rm -rf /app/autogpt_platform/backend/.venv/lib/python*/site-packages/pip* \
+    /root/.cache/pip \
+    /root/.cache/pypoetry
 
-# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
-FROM debian:13-slim AS migrate
-
-WORKDIR /app/autogpt_platform/backend
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.13 \
-    python3-pip \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
-
-# Copy Node.js from builder (needed for Prisma CLI)
-COPY --from=builder /usr/bin/node /usr/bin/node
-COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-COPY --from=builder /usr/bin/npm /usr/bin/npm
-
-# Copy Prisma binaries
-COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-# Install prisma-client-py directly (much smaller than copying full venv)
-RUN pip3 install prisma>=0.15.0 --break-system-packages
-
-COPY autogpt_platform/backend/schema.prisma ./
-COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
-COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
-COPY autogpt_platform/backend/migrations ./migrations
-
-# ============================== BACKEND SERVER ============================== #
-
-FROM debian:13-slim AS server
+FROM debian:13-slim AS server_dependencies
 
 WORKDIR /app
 
-ENV DEBIAN_FRONTEND=noninteractive
+ENV POETRY_HOME=/opt/poetry \
+    POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_CREATE=true \
+    POETRY_VIRTUALENVS_IN_PROJECT=true \
+    DEBIAN_FRONTEND=noninteractive
+ENV PATH=/opt/poetry/bin:$PATH
 
-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
-# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool (fallback when E2B is not configured).
-# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
-RUN apt-get update && apt-get install -y --no-install-recommends \
+# Install Python without upgrading system-managed packages
+RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
-    ffmpeg \
-    imagemagick \
-    jq \
-    ripgrep \
-    tree \
-    bubblewrap \
     && rm -rf /var/lib/apt/lists/*
 
-# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
+# Copy built artifacts from builder (cleaned of caches, __pycache__, and test dirs)
+COPY --from=builder /app /app
 COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
 COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
-# Copy Node.js installation for Prisma and agent-browser.
-# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
-# COPY resolves them to regular files, breaking require() paths. Recreate as
-# proper symlinks so npm/npx can find their modules.
+# Copy Node.js installation for Prisma
 COPY --from=builder /usr/bin/node /usr/bin/node
 COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
-    && ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+COPY --from=builder /usr/bin/npx /usr/bin/npx
 COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
 
-# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
-# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
-    libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
-    libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
-    libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
-    fonts-liberation libfontconfig1 \
-    && rm -rf /var/lib/apt/lists/* \
-    && npm install -g agent-browser \
-    && agent-browser install \
-    && rm -rf /tmp/* /root/.npm
+ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
+
+# Copy fresh source from context (overwrites builder's copy with latest source)
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
 
 WORKDIR /app/autogpt_platform/backend
 
-# Copy only the .venv from builder (not the entire /app directory)
-# The .venv includes the generated Prisma client
-COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
-ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
+FROM server_dependencies AS migrate
 
-# Copy dependency files + autogpt_libs (path dependency)
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
+# Migration stage only needs schema and migrations - much lighter than full backend
+COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
+COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
+COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
 
-# Copy backend code + docs (for Copilot docs search)
-COPY autogpt_platform/backend ./
+FROM server_dependencies AS server
+
+COPY autogpt_platform/backend /app/autogpt_platform/backend
 COPY docs /app/docs
-# Install the project package to create entry point scripts in .venv/bin/
-# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
-RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
-    poetry install --no-ansi --only-root
+RUN poetry install --no-ansi --only-root
 
 ENV PORT=8000
 
-CMD ["rest"]
+CMD ["poetry", "run", "rest"]
@@ -1,9 +1,4 @@
-"""Common test fixtures for server tests.
+"""Common test fixtures for server tests."""
 
-Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
-setup_test_user, and setup_admin_user are defined in the parent conftest.py
-(backend/conftest.py) and are available here automatically.
-"""
-
 import pytest
 from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
     return snapshot
 
 
+@pytest.fixture
+def test_user_id() -> str:
+    """Test user ID fixture."""
+    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+@pytest.fixture
+def admin_user_id() -> str:
+    """Admin user ID fixture."""
+    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+
+
+@pytest.fixture
+def target_user_id() -> str:
+    """Target user ID fixture."""
+    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
+
+
+@pytest.fixture
+async def setup_test_user(test_user_id):
+    """Create test user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the test user in the database using JWT token format
+    user_data = {
+        "sub": test_user_id,
+        "email": "test@example.com",
+        "user_metadata": {"name": "Test User"},
+    }
+    await get_or_create_user(user_data)
+    return test_user_id
+
+
+@pytest.fixture
+async def setup_admin_user(admin_user_id):
+    """Create admin user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the admin user in the database using JWT token format
+    user_data = {
+        "sub": admin_user_id,
+        "email": "test-admin@example.com",
+        "user_metadata": {"name": "Test Admin"},
+    }
+    await get_or_create_user(user_data)
+    return admin_user_id
+
+
 @pytest.fixture
 def mock_jwt_user(test_user_id):
     """Provide mock JWT payload for regular user testing."""
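The added fixtures above all build the same JWT-shaped payload for `get_or_create_user`; a tiny standalone sketch of that shape (the helper `get_or_create_user` and the field names come from the diff, while the factory below is hypothetical):

```python
# Hypothetical factory mirroring the payload built in setup_test_user /
# setup_admin_user above; "sub" carries the user ID, as in a JWT.
def make_user_payload(user_id: str, email: str, name: str) -> dict:
    return {
        "sub": user_id,
        "email": email,
        "user_metadata": {"name": name},
    }

payload = make_user_payload(
    "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "test@example.com", "Test User"
)
assert payload["sub"] == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
assert payload["user_metadata"]["name"] == "Test User"
```

Keeping the payload in one factory would avoid the duplication visible between the two fixtures, though whether that is worth it here is a judgment call.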
@@ -88,23 +88,20 @@ async def require_auth(
     )
 
 
-def require_permission(*permissions: APIKeyPermission):
+def require_permission(permission: APIKeyPermission):
     """
-    Dependency function for checking required permissions.
-    All listed permissions must be present.
+    Dependency function for checking specific permissions
     (works with API keys and OAuth tokens)
     """
 
-    async def check_permissions(
+    async def check_permission(
         auth: APIAuthorizationInfo = Security(require_auth),
     ) -> APIAuthorizationInfo:
-        missing = [p for p in permissions if p not in auth.scopes]
-        if missing:
+        if permission not in auth.scopes:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Missing required permission(s): "
-                f"{', '.join(p.value for p in missing)}",
+                detail=f"Missing required permission: {permission.value}",
             )
         return auth
 
-    return check_permissions
+    return check_permission
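The hunk above swaps a variadic permission check back to a single-permission one. The factory pattern itself, a function that closes over the required permissions and returns the actual checker, can be sketched without FastAPI (everything below is a simplified stand-in, not the project's middleware):

```python
from enum import Enum

class APIKeyPermission(Enum):  # simplified stand-in for the real prisma enum
    READ_BLOCK = "read_block"
    EXECUTE_BLOCK = "execute_block"

def require_permission(*permissions: APIKeyPermission):
    """Factory: returns a checker closed over the required permissions."""
    def check(scopes: set) -> list:
        # every listed permission must be present; report what is missing
        return [p.value for p in permissions if p not in scopes]
    return check

check = require_permission(APIKeyPermission.READ_BLOCK, APIKeyPermission.EXECUTE_BLOCK)
assert check({APIKeyPermission.READ_BLOCK, APIKeyPermission.EXECUTE_BLOCK}) == []
assert check({APIKeyPermission.READ_BLOCK}) == ["execute_block"]
```

In the real middleware the checker raises a 403 `HTTPException` instead of returning the missing list, and FastAPI's `Security(...)` supplies the resolved auth info.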
@@ -1,7 +1,7 @@
 import logging
 import urllib.parse
 from collections import defaultdict
-from typing import Annotated, Any, Optional, Sequence
+from typing import Annotated, Any, Literal, Optional, Sequence
 
 from fastapi import APIRouter, Body, HTTPException, Security
 from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,17 +9,15 @@ from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
 import backend.api.features.store.cache as store_cache
-import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-import backend.blocks
-from backend.api.external.middleware import require_auth, require_permission
+import backend.data.block
+from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data import user as user_db
 from backend.data.auth.base import APIAuthorizationInfo
 from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.executor.utils import add_graph_execution
-from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
 from backend.util.settings import Settings
 
 from .integrations import integrations_router
@@ -69,7 +67,7 @@ async def get_user_info(
     dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
 )
 async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
-    blocks = [block() for block in backend.blocks.get_blocks().values()]
+    blocks = [block() for block in backend.data.block.get_blocks().values()]
     return [b.to_dict() for b in blocks if not b.disabled]
 
 
@@ -85,7 +83,7 @@ async def execute_graph_block(
         require_permission(APIKeyPermission.EXECUTE_BLOCK)
     ),
 ) -> CompletedBlockOutput:
-    obj = backend.blocks.get_block(block_id)
+    obj = backend.data.block.get_block(block_id)
     if not obj:
         raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
     if obj.disabled:
@@ -97,43 +95,6 @@ async def execute_graph_block(
     return output
 
 
-@v1_router.post(
-    path="/graphs",
-    tags=["graphs"],
-    status_code=201,
-    dependencies=[
-        Security(
-            require_permission(
-                APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
-            )
-        )
-    ],
-)
-async def create_graph(
-    graph: graph_db.Graph,
-    auth: APIAuthorizationInfo = Security(
-        require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
-    ),
-) -> graph_db.GraphModel:
-    """
-    Create a new agent graph.
-
-    The graph will be validated and assigned a new ID.
-    It is automatically added to the user's library.
-    """
-    from backend.api.features.library import db as library_db
-
-    graph_model = graph_db.make_graph_model(graph, auth.user_id)
-    graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
-    graph_model.validate_graph(for_run=False)
-
-    await graph_db.create_graph(graph_model, user_id=auth.user_id)
-    await library_db.create_library_agent(graph_model, auth.user_id)
-    activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
-
-    return activated_graph
-
-
 @v1_router.post(
     path="/graphs/{graph_id}/execute/{graph_version}",
     tags=["graphs"],
@@ -231,13 +192,13 @@ async def get_graph_execution_results(
 @v1_router.get(
     path="/store/agents",
     tags=["store"],
-    dependencies=[Security(require_auth)],  # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.StoreAgentsResponse,
 )
 async def get_store_agents(
     featured: bool = False,
     creator: str | None = None,
-    sorted_by: store_db.StoreAgentsSortOptions | None = None,
+    sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
     search_query: str | None = None,
     category: str | None = None,
     page: int = 1,
@@ -279,7 +240,7 @@ async def get_store_agents(
 @v1_router.get(
     path="/store/agents/{username}/{agent_name}",
     tags=["store"],
-    dependencies=[Security(require_auth)],  # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.StoreAgentDetails,
 )
 async def get_store_agent(
@@ -307,13 +268,13 @@ async def get_store_agent(
 @v1_router.get(
     path="/store/creators",
     tags=["store"],
-    dependencies=[Security(require_auth)],  # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.CreatorsResponse,
 )
 async def get_store_creators(
     featured: bool = False,
     search_query: str | None = None,
-    sorted_by: store_db.StoreCreatorsSortOptions | None = None,
+    sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
     page: int = 1,
     page_size: int = 20,
 ) -> store_model.CreatorsResponse:
@@ -349,7 +310,7 @@ async def get_store_creators(
 @v1_router.get(
     path="/store/creators/{username}",
     tags=["store"],
-    dependencies=[Security(require_auth)],  # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.CreatorDetails,
 )
 async def get_store_creator(
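Several hunks above replace a custom sort-options type with a `Literal` annotation, which FastAPI validates automatically for query parameters. A self-contained sketch of that validation (the alias name mirrors the replaced type; the explicit validator is for illustration only):

```python
from typing import Literal, Optional, get_args

# Mirrors the Literal used for sorted_by in the store agents endpoint above
StoreAgentsSortOptions = Literal["rating", "runs", "name", "updated_at"]

def validate_sort(value: Optional[str]) -> Optional[str]:
    # FastAPI performs this check automatically for Literal-annotated query
    # params (returning a 422); shown explicitly here for illustration.
    if value is not None and value not in get_args(StoreAgentsSortOptions):
        raise ValueError(f"invalid sort option: {value!r}")
    return value

assert validate_sort("rating") == "rating"
assert validate_sort(None) is None
```

Using `Literal` keeps the accepted values visible in the OpenAPI schema without importing the store DB module just for its enum.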
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
 from pydantic import BaseModel, Field
 
 from backend.api.external.middleware import require_permission
-from backend.copilot.model import ChatSession
-from backend.copilot.tools import find_agent_tool, run_agent_tool
-from backend.copilot.tools.models import ToolResponseBase
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
+from backend.api.features.chat.tools.models import ToolResponseBase
 from backend.data.auth.base import APIAuthorizationInfo
 
 logger = logging.getLogger(__name__)
@@ -1,17 +1,8 @@
-from __future__ import annotations
-
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Literal, Optional
-
-import prisma.enums
-from pydantic import BaseModel, EmailStr
+from pydantic import BaseModel
 
 from backend.data.model import UserTransaction
 from backend.util.models import Pagination
 
-if TYPE_CHECKING:
-    from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
-
 
 class UserHistoryResponse(BaseModel):
     """Response model for listings with version history"""
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
 class AddUserCreditsResponse(BaseModel):
     new_balance: int
     transaction_key: str
-
-
-class CreateInvitedUserRequest(BaseModel):
-    email: EmailStr
-    name: Optional[str] = None
-
-
-class InvitedUserResponse(BaseModel):
-    id: str
-    email: str
-    status: prisma.enums.InvitedUserStatus
-    auth_user_id: Optional[str] = None
-    name: Optional[str] = None
-    tally_understanding: Optional[dict[str, Any]] = None
-    tally_status: prisma.enums.TallyComputationStatus
-    tally_computed_at: Optional[datetime] = None
-    tally_error: Optional[str] = None
-    created_at: datetime
-    updated_at: datetime
-
-    @classmethod
-    def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
-        return cls.model_validate(record.model_dump())
-
-
-class InvitedUsersResponse(BaseModel):
-    invited_users: list[InvitedUserResponse]
-    pagination: Pagination
-
-
-class BulkInvitedUserRowResponse(BaseModel):
-    row_number: int
-    email: Optional[str] = None
-    name: Optional[str] = None
-    status: Literal["CREATED", "SKIPPED", "ERROR"]
-    message: str
-    invited_user: Optional[InvitedUserResponse] = None
-
-
-class BulkInvitedUsersResponse(BaseModel):
-    created_count: int
-    skipped_count: int
-    error_count: int
-    results: list[BulkInvitedUserRowResponse]
-
-    @classmethod
-    def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
-        return cls(
-            created_count=result.created_count,
-            skipped_count=result.skipped_count,
-            error_count=result.error_count,
-            results=[
-                BulkInvitedUserRowResponse(
-                    row_number=row.row_number,
-                    email=row.email,
-                    name=row.name,
-                    status=row.status,
-                    message=row.message,
-                    invited_user=(
-                        InvitedUserResponse.from_record(row.invited_user)
-                        if row.invited_user is not None
-                        else None
-                    ),
-                )
-                for row in result.results
-            ],
-        )
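The removed `BulkInvitedUsersResponse.from_result` above is a domain-result-to-API-response mapping. The shape of that pattern can be sketched without pydantic or prisma (everything below is a simplified, hypothetical stand-in, not the deleted models):

```python
from dataclasses import dataclass, field

@dataclass
class RowResult:  # stand-in for one row of a BulkInvitedUsersResult
    row_number: int
    status: str
    message: str

@dataclass
class BulkResult:  # stand-in for the domain-layer result object
    created_count: int
    skipped_count: int
    error_count: int
    results: list = field(default_factory=list)

def to_response(result: BulkResult) -> dict:
    # Same idea as from_result: copy counts, map each row to response shape
    return {
        "created_count": result.created_count,
        "skipped_count": result.skipped_count,
        "error_count": result.error_count,
        "results": [(r.row_number, r.status, r.message) for r in result.results],
    }

r = BulkResult(1, 0, 1, [RowResult(1, "CREATED", "ok"), RowResult(2, "ERROR", "bad email")])
assert to_response(r)["results"] == [(1, "CREATED", "ok"), (2, "ERROR", "bad email")]
```

The real code used `pydantic.BaseModel.model_validate(record.model_dump())` for the per-row conversion, which gets validation for free at the API boundary.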
@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
 @router.get(
     "/listings",
     summary="Get Admin Listings History",
+    response_model=store_model.StoreListingsWithVersionsResponse,
 )
 async def get_admin_listings_with_versions(
     status: typing.Optional[prisma.enums.SubmissionStatus] = None,
     search: typing.Optional[str] = None,
     page: int = 1,
     page_size: int = 20,
-) -> store_model.StoreListingsWithVersionsAdminViewResponse:
+):
     """
     Get store listings with their version history for admins.
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
         page_size: Number of items per page
 
     Returns:
-        Paginated listings with their versions
+        StoreListingsWithVersionsResponse with listings and their versions
     """
-    listings = await store_db.get_admin_listings_with_versions(
-        status=status,
-        search_query=search,
-        page=page,
-        page_size=page_size,
-    )
-    return listings
+    try:
+        listings = await store_db.get_admin_listings_with_versions(
+            status=status,
+            search_query=search,
+            page=page,
+            page_size=page_size,
+        )
+        return listings
+    except Exception as e:
+        logger.exception("Error getting admin listings with versions: %s", e)
+        return fastapi.responses.JSONResponse(
+            status_code=500,
+            content={
+                "detail": "An error occurred while retrieving listings with versions"
+            },
+        )
 
 
 @router.post(
     "/submissions/{store_listing_version_id}/review",
     summary="Review Store Submission",
+    response_model=store_model.StoreSubmission,
 )
 async def review_submission(
     store_listing_version_id: str,
     request: store_model.ReviewSubmissionRequest,
     user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
-) -> store_model.StoreSubmissionAdminView:
+):
     """
     Review a store listing submission.
@@ -73,24 +84,31 @@ async def review_submission(
         user_id: Authenticated admin user performing the review
 
     Returns:
-        StoreSubmissionAdminView with updated review information
+        StoreSubmission with updated review information
     """
-    already_approved = await store_db.check_submission_already_approved(
-        store_listing_version_id=store_listing_version_id,
-    )
-    submission = await store_db.review_store_submission(
-        store_listing_version_id=store_listing_version_id,
-        is_approved=request.is_approved,
-        external_comments=request.comments,
-        internal_comments=request.internal_comments or "",
-        reviewer_id=user_id,
-    )
+    try:
+        already_approved = await store_db.check_submission_already_approved(
+            store_listing_version_id=store_listing_version_id,
+        )
+        submission = await store_db.review_store_submission(
+            store_listing_version_id=store_listing_version_id,
+            is_approved=request.is_approved,
+            external_comments=request.comments,
+            internal_comments=request.internal_comments or "",
+            reviewer_id=user_id,
+        )
 
-    state_changed = already_approved != request.is_approved
-    # Clear caches whenever approval state changes, since store visibility can change
-    if state_changed:
-        store_cache.clear_all_caches()
-    return submission
+        state_changed = already_approved != request.is_approved
+        # Clear caches when the request is approved as it updates what is shown on the store
+        if state_changed:
+            store_cache.clear_all_caches()
+        return submission
+    except Exception as e:
+        logger.exception("Error reviewing submission: %s", e)
+        return fastapi.responses.JSONResponse(
+            status_code=500,
+            content={"detail": "An error occurred while reviewing the submission"},
+        )
 
 
 @router.get(
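Both hunks above wrap the handler body in a try/except that logs the real error and returns an opaque 500 payload. That pattern can be factored into a decorator; a framework-free sketch (the decorator and handler names are hypothetical, and the dict stands in for FastAPI's `JSONResponse`):

```python
import logging

logger = logging.getLogger("admin_routes_sketch")

def with_generic_500(fn):
    """Decorator sketch: log the real error, return an opaque 500 payload."""
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.exception("handler failed: %s", e)
            # Opaque message: internal details go to logs, not to the client
            return {"status_code": 500, "detail": "An internal error occurred"}
    return wrapper

@with_generic_500
def review(version_id: str):
    if not version_id:
        raise ValueError("missing id")
    return {"status_code": 200, "id": version_id}

assert review("some-version-id")["status_code"] == 200
assert review("")["detail"] == "An internal error occurred"
```

A per-route try/except (as in the diff) lets each handler customize its message; the decorator form trades that for less repetition. FastAPI also supports app-level exception handlers for the same purpose.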
@@ -1,137 +0,0 @@
-import logging
-import math
-
-from autogpt_libs.auth import get_user_id, requires_admin_user
-from fastapi import APIRouter, File, Query, Security, UploadFile
-
-from backend.data.invited_user import (
-    bulk_create_invited_users_from_file,
-    create_invited_user,
-    list_invited_users,
-    retry_invited_user_tally,
-    revoke_invited_user,
-)
-from backend.data.tally import mask_email
-from backend.util.models import Pagination
-
-from .model import (
-    BulkInvitedUsersResponse,
-    CreateInvitedUserRequest,
-    InvitedUserResponse,
-    InvitedUsersResponse,
-)
-
-logger = logging.getLogger(__name__)
-
-
-router = APIRouter(
-    prefix="/admin",
-    tags=["users", "admin"],
-    dependencies=[Security(requires_admin_user)],
-)
-
-
-@router.get(
-    "/invited-users",
-    response_model=InvitedUsersResponse,
-    summary="List Invited Users",
-)
-async def get_invited_users(
-    admin_user_id: str = Security(get_user_id),
-    page: int = Query(1, ge=1),
-    page_size: int = Query(50, ge=1, le=200),
-) -> InvitedUsersResponse:
-    logger.info("Admin user %s requested invited users", admin_user_id)
-    invited_users, total = await list_invited_users(page=page, page_size=page_size)
-    return InvitedUsersResponse(
-        invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
-        pagination=Pagination(
-            total_items=total,
-            total_pages=max(1, math.ceil(total / page_size)),
-            current_page=page,
-            page_size=page_size,
-        ),
-    )
-
-
-@router.post(
-    "/invited-users",
-    response_model=InvitedUserResponse,
-    summary="Create Invited User",
-)
-async def create_invited_user_route(
-    request: CreateInvitedUserRequest,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s creating invited user for %s",
-        admin_user_id,
-        mask_email(request.email),
-    )
-    invited_user = await create_invited_user(request.email, request.name)
-    logger.info(
-        "Admin user %s created invited user %s",
-        admin_user_id,
-        invited_user.id,
-    )
-    return InvitedUserResponse.from_record(invited_user)
-
-
-@router.post(
-    "/invited-users/bulk",
-    response_model=BulkInvitedUsersResponse,
-    summary="Bulk Create Invited Users",
-    operation_id="postV2BulkCreateInvitedUsers",
-)
-async def bulk_create_invited_users_route(
-    file: UploadFile = File(...),
-    admin_user_id: str = Security(get_user_id),
-) -> BulkInvitedUsersResponse:
-    logger.info(
-        "Admin user %s bulk invited users from %s",
-        admin_user_id,
-        file.filename or "<unnamed>",
-    )
-    content = await file.read()
-    result = await bulk_create_invited_users_from_file(file.filename, content)
-    return BulkInvitedUsersResponse.from_result(result)
-
-
-@router.post(
-    "/invited-users/{invited_user_id}/revoke",
-    response_model=InvitedUserResponse,
-    summary="Revoke Invited User",
-)
-async def revoke_invited_user_route(
-    invited_user_id: str,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s revoking invited user %s", admin_user_id, invited_user_id
-    )
-    invited_user = await revoke_invited_user(invited_user_id)
-    logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
-    return InvitedUserResponse.from_record(invited_user)
-
-
-@router.post(
-    "/invited-users/{invited_user_id}/retry-tally",
-    response_model=InvitedUserResponse,
-    summary="Retry Invited User Tally",
-)
-async def retry_invited_user_tally_route(
-    invited_user_id: str,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s retrying Tally seed for invited user %s",
-        admin_user_id,
-        invited_user_id,
-    )
-    invited_user = await retry_invited_user_tally(invited_user_id)
-    logger.info(
-        "Admin user %s retried Tally seed for invited user %s",
-        admin_user_id,
-        invited_user_id,
-    )
-    return InvitedUserResponse.from_record(invited_user)
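The deleted listing route computes its page count as `max(1, math.ceil(total / page_size))`, which guarantees at least one page even when the listing is empty. The arithmetic in isolation:

```python
import math


def total_pages(total_items: int, page_size: int) -> int:
    # Ceiling division for the page count, clamped to at least 1 so an
    # empty result set still renders as "page 1 of 1"
    return max(1, math.ceil(total_items / page_size))
```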
@@ -1,168 +0,0 @@
-from datetime import datetime, timezone
-from unittest.mock import AsyncMock
-
-import fastapi
-import fastapi.testclient
-import prisma.enums
-import pytest
-import pytest_mock
-from autogpt_libs.auth.jwt_utils import get_jwt_payload
-
-from backend.data.invited_user import (
-    BulkInvitedUserRowResult,
-    BulkInvitedUsersResult,
-    InvitedUserRecord,
-)
-
-from .user_admin_routes import router as user_admin_router
-
-app = fastapi.FastAPI()
-app.include_router(user_admin_router)
-
-client = fastapi.testclient.TestClient(app)
-
-
-@pytest.fixture(autouse=True)
-def setup_app_admin_auth(mock_jwt_admin):
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def _sample_invited_user() -> InvitedUserRecord:
-    now = datetime.now(timezone.utc)
-    return InvitedUserRecord(
-        id="invite-1",
-        email="invited@example.com",
-        status=prisma.enums.InvitedUserStatus.INVITED,
-        auth_user_id=None,
-        name="Invited User",
-        tally_understanding=None,
-        tally_status=prisma.enums.TallyComputationStatus.PENDING,
-        tally_computed_at=None,
-        tally_error=None,
-        created_at=now,
-        updated_at=now,
-    )
-
-
-def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
-    return BulkInvitedUsersResult(
-        created_count=1,
-        skipped_count=1,
-        error_count=0,
-        results=[
-            BulkInvitedUserRowResult(
-                row_number=1,
-                email="invited@example.com",
-                name=None,
-                status="CREATED",
-                message="Invite created",
-                invited_user=_sample_invited_user(),
-            ),
-            BulkInvitedUserRowResult(
-                row_number=2,
-                email="duplicate@example.com",
-                name=None,
-                status="SKIPPED",
-                message="An invited user with this email already exists",
-                invited_user=None,
-            ),
-        ],
-    )
-
-
-def test_get_invited_users(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.list_invited_users",
-        AsyncMock(return_value=([_sample_invited_user()], 1)),
-    )
-
-    response = client.get("/admin/invited-users")
-
-    assert response.status_code == 200
-    data = response.json()
-    assert len(data["invited_users"]) == 1
-    assert data["invited_users"][0]["email"] == "invited@example.com"
-    assert data["invited_users"][0]["status"] == "INVITED"
-    assert data["pagination"]["total_items"] == 1
-    assert data["pagination"]["current_page"] == 1
-    assert data["pagination"]["page_size"] == 50
-
-
-def test_create_invited_user(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.create_invited_user",
-        AsyncMock(return_value=_sample_invited_user()),
-    )
-
-    response = client.post(
-        "/admin/invited-users",
-        json={"email": "invited@example.com", "name": "Invited User"},
-    )
-
-    assert response.status_code == 200
-    data = response.json()
-    assert data["email"] == "invited@example.com"
-    assert data["name"] == "Invited User"
-
-
-def test_bulk_create_invited_users(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
-        AsyncMock(return_value=_sample_bulk_invited_users_result()),
-    )
-
-    response = client.post(
-        "/admin/invited-users/bulk",
-        files={
-            "file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
-        },
-    )
-
-    assert response.status_code == 200
-    data = response.json()
-    assert data["created_count"] == 1
-    assert data["skipped_count"] == 1
-    assert data["results"][0]["status"] == "CREATED"
-    assert data["results"][1]["status"] == "SKIPPED"
-
-
-def test_revoke_invited_user(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    revoked = _sample_invited_user().model_copy(
-        update={"status": prisma.enums.InvitedUserStatus.REVOKED}
-    )
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.revoke_invited_user",
-        AsyncMock(return_value=revoked),
-    )
-
-    response = client.post("/admin/invited-users/invite-1/revoke")
-
-    assert response.status_code == 200
-    assert response.json()["status"] == "REVOKED"
-
-
-def test_retry_invited_user_tally(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    retried = _sample_invited_user().model_copy(
-        update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
-    )
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
-        AsyncMock(return_value=retried),
-    )
-
-    response = client.post("/admin/invited-users/invite-1/retry-tally")
-
-    assert response.status_code == 200
-    assert response.json()["tally_status"] == "RUNNING"
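The deleted tests derive variant records with pydantic's `model_copy(update={...})`, an immutable copy-with-override. The same pattern can be sketched with stdlib frozen dataclasses (pydantic is not required for the illustration; `InviteRecord` is a hypothetical stand-in):

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class InviteRecord:
    id: str
    status: str


original = InviteRecord(id="invite-1", status="INVITED")
# replace() returns a new instance with the given field overridden,
# leaving the original untouched - analogous to model_copy(update=...)
revoked = replace(original, status="REVOKED")
```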
@@ -1,26 +1,20 @@
 import logging
 from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
 from difflib import SequenceMatcher
-from typing import Any, Sequence, get_args, get_origin
+from typing import Sequence

 import prisma
-from prisma.enums import ContentType
-from prisma.models import mv_suggested_blocks

 import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-from backend.api.features.store.hybrid_search import unified_hybrid_search
+import backend.data.block
 from backend.blocks import load_all_blocks
-from backend.blocks._base import (
-    AnyBlockSchema,
-    BlockCategory,
-    BlockInfo,
-    BlockSchema,
-    BlockType,
-)
 from backend.blocks.llm import LlmModel
+from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
+from backend.data.db import query_raw_with_schema
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
 from backend.util.models import Pagination
@@ -28,7 +22,7 @@ from backend.util.models import Pagination
 from .model import (
     BlockCategoryResponse,
     BlockResponse,
-    BlockTypeFilter,
+    BlockType,
     CountResponse,
     FilterType,
     Provider,
@@ -43,16 +37,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
 MAX_MARKETPLACE_AGENT_RESULTS = 100
 MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
-
-# Boost blocks over marketplace agents in search results
-BLOCK_SCORE_BOOST = 50.0
-
-# Block IDs to exclude from search results
-EXCLUDED_BLOCK_IDS = frozenset(
-    {
-        "e189baac-8c20-45a1-94a7-55177ea42565",  # AgentExecutorBlock
-    }
-)

 SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent

@@ -75,8 +59,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse

     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
-        # Skip disabled and excluded blocks
-        if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
+        # Skip disabled blocks
+        if block.disabled:
             continue
         # Skip blocks that don't have categories (all should have at least one)
         if not block.categories:
@@ -104,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 def get_blocks(
     *,
     category: str | None = None,
-    type: BlockTypeFilter | None = None,
+    type: BlockType | None = None,
     provider: ProviderName | None = None,
     page: int = 1,
     page_size: int = 50,
@@ -127,9 +111,6 @@ def get_blocks(
         # Skip disabled blocks
         if block.disabled:
             continue
-        # Skip excluded blocks
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
         # Skip blocks that don't match the category
         if category and category not in {c.name.lower() for c in block.categories}:
             continue
@@ -269,25 +250,14 @@ async def _build_cached_search_results(
         "my_agents": 0,
     }

-    # Use hybrid search when query is present, otherwise list all blocks
-    if (include_blocks or include_integrations) and normalized_query:
-        block_results, block_total, integration_total = await _hybrid_search_blocks(
-            query=search_query,
-            include_blocks=include_blocks,
-            include_integrations=include_integrations,
-        )
-        scored_items.extend(block_results)
-        total_items["blocks"] = block_total
-        total_items["integrations"] = integration_total
-    elif include_blocks or include_integrations:
-        # No query - list all blocks using in-memory approach
-        block_results, block_total, integration_total = _collect_block_results(
-            include_blocks=include_blocks,
-            include_integrations=include_integrations,
-        )
-        scored_items.extend(block_results)
-        total_items["blocks"] = block_total
-        total_items["integrations"] = integration_total
+    block_results, block_total, integration_total = _collect_block_results(
+        normalized_query=normalized_query,
+        include_blocks=include_blocks,
+        include_integrations=include_integrations,
+    )
+    scored_items.extend(block_results)
+    total_items["blocks"] = block_total
+    total_items["integrations"] = integration_total

     if include_library_agents:
         library_response = await library_db.list_library_agents(
@@ -332,14 +302,10 @@ async def _build_cached_search_results(

 def _collect_block_results(
     *,
+    normalized_query: str,
     include_blocks: bool,
     include_integrations: bool,
 ) -> tuple[list[_ScoredItem], int, int]:
-    """
-    Collect all blocks for listing (no search query).
-
-    All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
-    """
     results: list[_ScoredItem] = []
     block_count = 0
     integration_count = 0
@@ -352,10 +318,6 @@ def _collect_block_results(
         if block.disabled:
             continue
-
-        # Skip excluded blocks
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue

         block_info = block.get_info()
         credentials = list(block.input_schema.get_credentials_fields().values())
         is_integration = len(credentials) > 0
@@ -365,6 +327,10 @@ def _collect_block_results(
         if not is_integration and not include_blocks:
             continue

+        score = _score_block(block, block_info, normalized_query)
+        if not _should_include_item(score, normalized_query):
+            continue
+
         filter_type: FilterType = "integrations" if is_integration else "blocks"
         if is_integration:
             integration_count += 1
@@ -375,122 +341,8 @@ def _collect_block_results(
             _ScoredItem(
                 item=block_info,
                 filter_type=filter_type,
-                score=BLOCK_SCORE_BOOST,
-                sort_key=block_info.name.lower(),
-            )
-        )
-
-    return results, block_count, integration_count
-
-
-async def _hybrid_search_blocks(
-    *,
-    query: str,
-    include_blocks: bool,
-    include_integrations: bool,
-) -> tuple[list[_ScoredItem], int, int]:
-    """
-    Search blocks using hybrid search with builder-specific filtering.
-
-    Uses unified_hybrid_search for semantic + lexical search, then applies
-    post-filtering for block/integration types and scoring adjustments.
-
-    Scoring:
-    - Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
-      to prioritize blocks over marketplace agents in combined results
-    - +30 for exact name match, +15 for prefix name match
-    - +20 if the block has an LlmModel field and the query matches an LLM model name
-
-    Args:
-        query: The search query string
-        include_blocks: Whether to include regular blocks
-        include_integrations: Whether to include integration blocks
-
-    Returns:
-        Tuple of (scored_items, block_count, integration_count)
-    """
-    results: list[_ScoredItem] = []
-    block_count = 0
-    integration_count = 0
-
-    if not include_blocks and not include_integrations:
-        return results, block_count, integration_count
-
-    normalized_query = query.strip().lower()
-
-    # Fetch more results to account for post-filtering
-    search_results, _ = await unified_hybrid_search(
-        query=query,
-        content_types=[ContentType.BLOCK],
-        page=1,
-        page_size=150,
-        min_score=0.10,
-    )
-
-    # Load all blocks for getting BlockInfo
-    all_blocks = load_all_blocks()
-
-    for result in search_results:
-        block_id = result["content_id"]
-
-        # Skip excluded blocks
-        if block_id in EXCLUDED_BLOCK_IDS:
-            continue
-
-        metadata = result.get("metadata", {})
-        hybrid_score = result.get("relevance", 0.0)
-
-        # Get the actual block class
-        if block_id not in all_blocks:
-            continue
-
-        block_cls = all_blocks[block_id]
-        block: AnyBlockSchema = block_cls()
-
-        if block.disabled:
-            continue
-
-        # Check block/integration filter using metadata
-        is_integration = metadata.get("is_integration", False)
-
-        if is_integration and not include_integrations:
-            continue
-        if not is_integration and not include_blocks:
-            continue
-
-        # Get block info
-        block_info = block.get_info()
-
-        # Calculate final score: scale hybrid score and add builder-specific bonuses
-        # Hybrid scores are 0-1, builder scores were 0-200+
-        # Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
-        final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
-
-        # Add LLM model match bonus
-        has_llm_field = metadata.get("has_llm_model_field", False)
-        if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
-            final_score += 20
-
-        # Add exact/prefix match bonus for deterministic tie-breaking
-        name = block_info.name.lower()
-        if name == normalized_query:
-            final_score += 30
-        elif name.startswith(normalized_query):
-            final_score += 15
-
-        # Track counts
-        filter_type: FilterType = "integrations" if is_integration else "blocks"
-        if is_integration:
-            integration_count += 1
-        else:
-            block_count += 1
-
-        results.append(
-            _ScoredItem(
-                item=block_info,
-                filter_type=filter_type,
-                score=final_score,
-                sort_key=name,
+                score=score,
+                sort_key=_get_item_name(block_info),
             )
         )

@@ -615,8 +467,6 @@ async def _get_static_counts():
         block: AnyBlockSchema = block_type()
         if block.disabled:
             continue
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue

         all_blocks += 1

@@ -643,25 +493,47 @@ async def _get_static_counts():
     }


-def _contains_type(annotation: Any, target: type) -> bool:
-    """Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
-    if annotation is target:
-        return True
-    origin = get_origin(annotation)
-    if origin is None:
-        return False
-    return any(_contains_type(arg, target) for arg in get_args(annotation))
-
-
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
-        if _contains_type(field.annotation, LlmModel):
+        if field.annotation == LlmModel:
             # Check if query matches any value in llm_models
             if any(query in name for name in llm_models):
                 return True
     return False


+def _score_block(
+    block: AnyBlockSchema,
+    block_info: BlockInfo,
+    normalized_query: str,
+) -> float:
+    if not normalized_query:
+        return 0.0
+
+    name = block_info.name.lower()
+    description = block_info.description.lower()
+    score = _score_primary_fields(name, description, normalized_query)
+
+    category_text = " ".join(
+        category.get("category", "").lower() for category in block_info.categories
+    )
+    score += _score_additional_field(category_text, normalized_query, 12, 6)
+
+    credentials_info = block.input_schema.get_credentials_fields_info().values()
+    provider_names = [
+        provider.value.lower()
+        for info in credentials_info
+        for provider in info.provider
+    ]
+    provider_text = " ".join(provider_names)
+    score += _score_additional_field(provider_text, normalized_query, 15, 6)
+
+    if _matches_llm_model(block.input_schema, normalized_query):
+        score += 20
+
+    return score
+
+
 def _score_library_agent(
     agent: library_model.LibraryAgent,
     normalized_query: str,
@@ -768,32 +640,45 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
     return providers


-@cached(ttl_seconds=3600, shared_cache=True)
+@cached(ttl_seconds=3600)
 async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
-    """Return the most-executed blocks from the last 14 days.
-
-    Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
-    and returns the top `count` blocks sorted by execution count, excluding
-    Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
-    """
-    results = await mv_suggested_blocks.prisma().find_many()
+    suggested_blocks = []
+    # Sum the number of executions for each block type
+    # Prisma cannot group by nested relations, so we do a raw query
+    # Calculate the cutoff timestamp
+    timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
+
+    results = await query_raw_with_schema(
+        """
+        SELECT
+            agent_node."agentBlockId" AS block_id,
+            COUNT(execution.id) AS execution_count
+        FROM {schema_prefix}"AgentNodeExecution" execution
+        JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
+        WHERE execution."endedTime" >= $1::timestamp
+        GROUP BY agent_node."agentBlockId"
+        ORDER BY execution_count DESC;
+        """,
+        timestamp_threshold,
+    )
+
     # Get the top blocks based on execution count
-    # But ignore Input, Output, Agent, and excluded blocks
+    # But ignore Input and Output blocks
     blocks: list[tuple[BlockInfo, int]] = []
-    execution_counts = {row.block_id: row.execution_count for row in results}

     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
         if block.disabled or block.block_type in (
-            BlockType.INPUT,
-            BlockType.OUTPUT,
-            BlockType.AGENT,
+            backend.data.block.BlockType.INPUT,
+            backend.data.block.BlockType.OUTPUT,
+            backend.data.block.BlockType.AGENT,
         ):
             continue
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
-        execution_count = execution_counts.get(block.id, 0)
+        # Find the execution count for this block
+        execution_count = next(
+            (row["execution_count"] for row in results if row["block_id"] == block.id),
+            0,
+        )
         blocks.append((block.get_info(), execution_count))
     # Sort blocks by execution count
     blocks.sort(key=lambda x: x[1], reverse=True)
@@ -4,7 +4,7 @@ from pydantic import BaseModel

 import backend.api.features.library.model as library_model
 import backend.api.features.store.model as store_model
-from backend.blocks._base import BlockInfo
+from backend.data.block import BlockInfo
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination

@@ -15,7 +15,7 @@ FilterType = Literal[
     "my_agents",
 ]

-BlockTypeFilter = Literal["all", "input", "action", "output"]
+BlockType = Literal["all", "input", "action", "output"]


 class SearchEntry(BaseModel):
@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):

 # Suggestions
 class SuggestionsResponse(BaseModel):
+    otto_suggestions: list[str]
     recent_searches: list[SearchEntry]
     providers: list[ProviderName]
     top_blocks: list[BlockInfo]
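The model diff renames the `Literal` alias (`BlockTypeFilter` vs `BlockType`), and the routes diff below it drops a `get_args`-based validator. In isolation, validating a string against a `Literal` alias looks like this (`validate_filter` is a hypothetical wrapper, not the project's API):

```python
from typing import Literal, get_args

BlockTypeFilter = Literal["all", "input", "action", "output"]

# get_args() on a Literal alias yields its allowed values as a tuple,
# which makes a cheap membership check possible at request time
VALID_FILTER_VALUES = get_args(BlockTypeFilter)


def validate_filter(value: str) -> str:
    if value not in VALID_FILTER_VALUES:
        raise ValueError(f"invalid filter: {value!r}")
    return value
```

Typed query parameters (as in the new route signature) push this validation into FastAPI itself, which is why the manual check could be deleted.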
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated, Sequence, cast, get_args
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
|
|||||||
from . import db as builder_db
|
from . import db as builder_db
|
||||||
from . import model as builder_model
|
from . import model as builder_model
|
||||||
|
|
||||||
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = fastapi.APIRouter(
|
router = fastapi.APIRouter(
|
||||||
@@ -51,6 +49,11 @@ async def get_suggestions(
|
|||||||
Get all suggestions for the Blocks Menu.
|
Get all suggestions for the Blocks Menu.
|
||||||
"""
|
"""
|
||||||
return builder_model.SuggestionsResponse(
|
return builder_model.SuggestionsResponse(
|
||||||
|
otto_suggestions=[
|
||||||
|
"What blocks do I need to get started?",
|
||||||
|
"Help me create a list",
|
||||||
|
"Help me feed my data to Google Maps",
|
||||||
|
],
|
||||||
recent_searches=await builder_db.get_recent_searches(user_id),
|
recent_searches=await builder_db.get_recent_searches(user_id),
|
||||||
providers=[
|
providers=[
|
||||||
ProviderName.TWITTER,
|
ProviderName.TWITTER,
|
||||||
```diff
@@ -85,7 +88,7 @@ async def get_block_categories(
 )
 async def get_blocks(
     category: Annotated[str | None, fastapi.Query()] = None,
-    type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
+    type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
     provider: Annotated[ProviderName | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
     page_size: Annotated[int, fastapi.Query()] = 50,
```
```diff
@@ -148,7 +151,7 @@ async def get_providers(
 async def search(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
     search_query: Annotated[str | None, fastapi.Query()] = None,
-    filter: Annotated[str | None, fastapi.Query()] = None,
+    filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
     search_id: Annotated[str | None, fastapi.Query()] = None,
     by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
```
```diff
@@ -157,20 +160,9 @@ async def search(
     """
     Search for blocks (including integrations), marketplace agents, and user library agents.
     """
-    # Parse and validate filter parameter
-    filters: list[builder_model.FilterType]
-    if filter:
-        filter_values = [f.strip() for f in filter.split(",")]
-        invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
-        if invalid_filters:
-            raise fastapi.HTTPException(
-                status_code=400,
-                detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
-                f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
-            )
-        filters = cast(list[builder_model.FilterType], filter_values)
-    else:
-        filters = [
+    # If no filters are provided, then we will return all types
+    if not filter:
+        filter = [
             "blocks",
             "integrations",
             "marketplace_agents",
```
```diff
@@ -182,7 +174,7 @@ async def search(
     cached_results = await builder_db.get_sorted_search_results(
         user_id=user_id,
         search_query=search_query,
-        filters=filters,
+        filters=filter,
         by_creator=by_creator,
     )
 
```
```diff
@@ -204,7 +196,7 @@ async def search(
             user_id,
             builder_model.SearchEntry(
                 search_query=search_query,
-                filter=filters,
+                filter=filter,
                 by_creator=by_creator,
                 search_id=search_id,
             ),
```
`autogpt_platform/backend/backend/api/features/chat/config.py` (new file, +96):

```python
"""Configuration management for chat system."""

import os

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings


class ChatConfig(BaseSettings):
    """Configuration for the chat system."""

    # OpenAI API Configuration
    model: str = Field(
        default="anthropic/claude-opus-4.5", description="Default model to use"
    )
    title_model: str = Field(
        default="openai/gpt-4o-mini",
        description="Model to use for generating session titles (should be fast/cheap)",
    )
    api_key: str | None = Field(default=None, description="OpenAI API key")
    base_url: str | None = Field(
        default="https://openrouter.ai/api/v1",
        description="Base URL for API (e.g., for OpenRouter)",
    )

    # Session TTL Configuration - 12 hours
    session_ttl: int = Field(default=43200, description="Session TTL in seconds")

    # Streaming Configuration
    max_context_messages: int = Field(
        default=50, ge=1, le=200, description="Maximum context messages"
    )

    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
    max_retries: int = Field(default=3, description="Maximum number of retries")
    max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
    max_agent_schedules: int = Field(
        default=30, description="Maximum number of agent schedules"
    )

    # Long-running operation configuration
    long_running_operation_ttl: int = Field(
        default=600,
        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
    )

    # Langfuse Prompt Management Configuration
    # Note: Langfuse credentials are in Settings().secrets (settings.py)
    langfuse_prompt_name: str = Field(
        default="CoPilot Prompt",
        description="Name of the prompt in Langfuse to fetch",
    )

    @field_validator("api_key", mode="before")
    @classmethod
    def get_api_key(cls, v):
        """Get API key from environment if not provided."""
        if v is None:
            # Try to get from environment variables
            # First check for CHAT_API_KEY (Pydantic prefix)
            v = os.getenv("CHAT_API_KEY")
            if not v:
                # Fall back to OPEN_ROUTER_API_KEY
                v = os.getenv("OPEN_ROUTER_API_KEY")
            if not v:
                # Fall back to OPENAI_API_KEY
                v = os.getenv("OPENAI_API_KEY")
        return v

    @field_validator("base_url", mode="before")
    @classmethod
    def get_base_url(cls, v):
        """Get base URL from environment if not provided."""
        if v is None:
            # Check for OpenRouter or custom base URL
            v = os.getenv("CHAT_BASE_URL")
            if not v:
                v = os.getenv("OPENROUTER_BASE_URL")
            if not v:
                v = os.getenv("OPENAI_BASE_URL")
            if not v:
                v = "https://openrouter.ai/api/v1"
        return v

    # Prompt paths for different contexts
    PROMPT_PATHS: dict[str, str] = {
        "default": "prompts/chat_system.md",
        "onboarding": "prompts/onboarding_system.md",
    }

    class Config:
        """Pydantic config."""

        env_file = ".env"
        env_file_encoding = "utf-8"
        extra = "ignore"  # Ignore extra environment variables
```
`autogpt_platform/backend/backend/api/features/chat/db.py` (new file, +291):

```python
"""Database operations for chat sessions."""

import asyncio
import logging
from datetime import UTC, datetime
from typing import Any, cast

from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
    ChatMessageCreateInput,
    ChatSessionCreateInput,
    ChatSessionUpdateInput,
    ChatSessionWhereInput,
)

from backend.data.db import transaction
from backend.util.json import SafeJson

logger = logging.getLogger(__name__)


async def get_chat_session(session_id: str) -> PrismaChatSession | None:
    """Get a chat session by ID from the database."""
    session = await PrismaChatSession.prisma().find_unique(
        where={"id": session_id},
        include={"Messages": True},
    )
    if session and session.Messages:
        # Sort messages by sequence in Python - Prisma Python client doesn't support
        # order_by in include clauses (unlike Prisma JS), so we sort after fetching
        session.Messages.sort(key=lambda m: m.sequence)
    return session


async def create_chat_session(
    session_id: str,
    user_id: str,
) -> PrismaChatSession:
    """Create a new chat session in the database."""
    data = ChatSessionCreateInput(
        id=session_id,
        userId=user_id,
        credentials=SafeJson({}),
        successfulAgentRuns=SafeJson({}),
        successfulAgentSchedules=SafeJson({}),
    )
    return await PrismaChatSession.prisma().create(
        data=data,
        include={"Messages": True},
    )


async def update_chat_session(
    session_id: str,
    credentials: dict[str, Any] | None = None,
    successful_agent_runs: dict[str, Any] | None = None,
    successful_agent_schedules: dict[str, Any] | None = None,
    total_prompt_tokens: int | None = None,
    total_completion_tokens: int | None = None,
    title: str | None = None,
) -> PrismaChatSession | None:
    """Update a chat session's metadata."""
    data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

    if credentials is not None:
        data["credentials"] = SafeJson(credentials)
    if successful_agent_runs is not None:
        data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
    if successful_agent_schedules is not None:
        data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
    if total_prompt_tokens is not None:
        data["totalPromptTokens"] = total_prompt_tokens
    if total_completion_tokens is not None:
        data["totalCompletionTokens"] = total_completion_tokens
    if title is not None:
        data["title"] = title

    session = await PrismaChatSession.prisma().update(
        where={"id": session_id},
        data=data,
        include={"Messages": True},
    )
    if session and session.Messages:
        # Sort in Python - Prisma Python doesn't support order_by in include clauses
        session.Messages.sort(key=lambda m: m.sequence)
    return session


async def add_chat_message(
    session_id: str,
    role: str,
    sequence: int,
    content: str | None = None,
    name: str | None = None,
    tool_call_id: str | None = None,
    refusal: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
    """Add a message to a chat session."""
    # Build input dict dynamically rather than using ChatMessageCreateInput directly
    # because Prisma's TypedDict validation rejects optional fields set to None.
    # We only include fields that have values, then cast at the end.
    data: dict[str, Any] = {
        "Session": {"connect": {"id": session_id}},
        "role": role,
        "sequence": sequence,
    }

    # Add optional string fields
    if content is not None:
        data["content"] = content
    if name is not None:
        data["name"] = name
    if tool_call_id is not None:
        data["toolCallId"] = tool_call_id
    if refusal is not None:
        data["refusal"] = refusal

    # Add optional JSON fields only when they have values
    if tool_calls is not None:
        data["toolCalls"] = SafeJson(tool_calls)
    if function_call is not None:
        data["functionCall"] = SafeJson(function_call)

    # Run message create and session timestamp update in parallel for lower latency
    _, message = await asyncio.gather(
        PrismaChatSession.prisma().update(
            where={"id": session_id},
            data={"updatedAt": datetime.now(UTC)},
        ),
        PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
    )
    return message


async def add_chat_messages_batch(
    session_id: str,
    messages: list[dict[str, Any]],
    start_sequence: int,
) -> list[PrismaChatMessage]:
    """Add multiple messages to a chat session in a batch.

    Uses a transaction for atomicity - if any message creation fails,
    the entire batch is rolled back.
    """
    if not messages:
        return []

    created_messages = []

    async with transaction() as tx:
        for i, msg in enumerate(messages):
            # Build input dict dynamically rather than using ChatMessageCreateInput
            # directly because Prisma's TypedDict validation rejects optional fields
            # set to None. We only include fields that have values, then cast.
            data: dict[str, Any] = {
                "Session": {"connect": {"id": session_id}},
                "role": msg["role"],
                "sequence": start_sequence + i,
            }

            # Add optional string fields
            if msg.get("content") is not None:
                data["content"] = msg["content"]
            if msg.get("name") is not None:
                data["name"] = msg["name"]
            if msg.get("tool_call_id") is not None:
                data["toolCallId"] = msg["tool_call_id"]
            if msg.get("refusal") is not None:
                data["refusal"] = msg["refusal"]

            # Add optional JSON fields only when they have values
            if msg.get("tool_calls") is not None:
                data["toolCalls"] = SafeJson(msg["tool_calls"])
            if msg.get("function_call") is not None:
                data["functionCall"] = SafeJson(msg["function_call"])

            created = await PrismaChatMessage.prisma(tx).create(
                data=cast(ChatMessageCreateInput, data)
            )
            created_messages.append(created)

        # Update session's updatedAt timestamp within the same transaction.
        # Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
        # separately via update_chat_session() after streaming completes.
        await PrismaChatSession.prisma(tx).update(
            where={"id": session_id},
            data={"updatedAt": datetime.now(UTC)},
        )

    return created_messages


async def get_user_chat_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
) -> list[PrismaChatSession]:
    """Get chat sessions for a user, ordered by most recent."""
    return await PrismaChatSession.prisma().find_many(
        where={"userId": user_id},
        order={"updatedAt": "desc"},
        take=limit,
        skip=offset,
    )


async def get_user_session_count(user_id: str) -> int:
    """Get the total number of chat sessions for a user."""
    return await PrismaChatSession.prisma().count(where={"userId": user_id})


async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
    """Delete a chat session and all its messages.

    Args:
        session_id: The session ID to delete.
        user_id: If provided, validates that the session belongs to this user
            before deletion. This prevents unauthorized deletion of other
            users' sessions.

    Returns:
        True if deleted successfully, False otherwise.
    """
    try:
        # Build typed where clause with optional user_id validation
        where_clause: ChatSessionWhereInput = {"id": session_id}
        if user_id is not None:
            where_clause["userId"] = user_id

        result = await PrismaChatSession.prisma().delete_many(where=where_clause)
        if result == 0:
            logger.warning(
                f"No session deleted for {session_id} "
                f"(user_id validation: {user_id is not None})"
            )
            return False
        return True
    except Exception as e:
        logger.error(f"Failed to delete chat session {session_id}: {e}")
        return False


async def get_chat_session_message_count(session_id: str) -> int:
    """Get the number of messages in a chat session."""
    count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
    return count


async def update_tool_message_content(
    session_id: str,
    tool_call_id: str,
    new_content: str,
) -> bool:
    """Update the content of a tool message in chat history.

    Used by background tasks to update pending operation messages with final results.

    Args:
        session_id: The chat session ID.
        tool_call_id: The tool call ID to find the message.
        new_content: The new content to set.

    Returns:
        True if a message was updated, False otherwise.
    """
    try:
        result = await PrismaChatMessage.prisma().update_many(
            where={
                "sessionId": session_id,
                "toolCallId": tool_call_id,
            },
            data={
                "content": new_content,
            },
        )
        if result == 0:
            logger.warning(
                f"No message found to update for session {session_id}, "
                f"tool_call_id {tool_call_id}"
            )
            return False
        return True
    except Exception as e:
        logger.error(
            f"Failed to update tool message for session {session_id}, "
            f"tool_call_id {tool_call_id}: {e}"
        )
        return False
```
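The optional-field handling repeated in `add_chat_message` and `add_chat_messages_batch` can be distilled into a small standalone helper. This is a hypothetical sketch for illustration (the real code passes the resulting dict to Prisma's `create` after a `cast`); it shows the skip-`None` pattern and the snake_case-to-camelCase column mapping:

```python
from typing import Any


def build_message_input(
    session_id: str,
    role: str,
    sequence: int,
    **optional: Any,
) -> dict[str, Any]:
    # Required fields are always present; optional fields are added only when
    # they carry a value, since Prisma's TypedDict validation rejects explicit None.
    data: dict[str, Any] = {
        "Session": {"connect": {"id": session_id}},
        "role": role,
        "sequence": sequence,
    }
    field_map = {  # snake_case argument -> Prisma camelCase column
        "content": "content",
        "name": "name",
        "tool_call_id": "toolCallId",
        "refusal": "refusal",
    }
    for arg, column in field_map.items():
        if optional.get(arg) is not None:
            data[column] = optional[arg]
    return data


if __name__ == "__main__":
    print(build_message_input("s1", "user", 0, content="hi", tool_call_id=None))
```

Fields passed as `None` simply never appear in the output dict, so the database defaults (or NULLs) apply without tripping the input validation.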
```diff
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime
-from typing import Any, Self, cast
+from typing import Any
 from weakref import WeakValueDictionary
 
 from openai.types.chat import (
```
```diff
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
 from prisma.models import ChatSession as PrismaChatSession
 from pydantic import BaseModel
 
-from backend.data.db_accessors import chat_db
 from backend.data.redis_client import get_redis_async
 from backend.util import json
 from backend.util.exceptions import DatabaseError, RedisError
 
+from . import db as chat_db
 from .config import ChatConfig
 
 logger = logging.getLogger(__name__)
 config = ChatConfig()
 
 
+def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
+    """Parse a JSON field that may be stored as string or already parsed."""
+    if value is None:
+        return default
+    if isinstance(value, str):
+        return json.loads(value)
+    return value
+
+
 # Redis cache key prefix for chat sessions
 CHAT_SESSION_CACHE_PREFIX = "chat:session:"
 
```
```diff
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
     return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
 
 
-# ===================== Chat data models ===================== #
+# Session-level locks to prevent race conditions during concurrent upserts.
+# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
+# preventing unbounded memory growth while maintaining lock semantics for active sessions.
+# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
+# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
+_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+_session_locks_mutex = asyncio.Lock()
+
+
+async def _get_session_lock(session_id: str) -> asyncio.Lock:
+    """Get or create a lock for a specific session to prevent concurrent upserts.
+
+    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
+    when no coroutine holds a reference to them, preventing memory leaks from
+    unbounded growth of session locks.
+    """
+    async with _session_locks_mutex:
+        lock = _session_locks.get(session_id)
+        if lock is None:
+            lock = asyncio.Lock()
+            _session_locks[session_id] = lock
+        return lock
 
 
 class ChatMessage(BaseModel):
```
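The per-session lock pattern introduced in this hunk can be exercised in isolation. Below is a minimal standalone sketch (not the PR's module; names are shortened) showing that two tasks hitting the same session serialize their critical sections, while the `WeakValueDictionary` lets idle locks be garbage collected:

```python
import asyncio
from weakref import WeakValueDictionary

_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_locks_mutex = asyncio.Lock()


async def get_session_lock(session_id: str) -> asyncio.Lock:
    # Check-then-create under a global mutex, as in the diff above.
    async with _locks_mutex:
        lock = _locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _locks[session_id] = lock
        return lock


async def critical(session_id: str, log: list[str]) -> None:
    lock = await get_session_lock(session_id)
    async with lock:
        log.append(f"enter:{session_id}")
        await asyncio.sleep(0)  # yield so a racing task could interleave
        log.append(f"exit:{session_id}")


async def main() -> list[str]:
    log: list[str] = []
    # Two concurrent tasks on the same session must not overlap their sections.
    await asyncio.gather(critical("s1", log), critical("s1", log))
    return log


if __name__ == "__main__":
    print(asyncio.run(main()))  # enter/exit pairs never interleave
```

Holding the lock across the `await` is what makes the enter/exit pairs atomic; without it, the `sleep(0)` yield point would let the second task interleave.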
```diff
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
     tool_calls: list[dict] | None = None
     function_call: dict | None = None
 
-    @staticmethod
-    def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
-        """Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
-        return ChatMessage(
-            role=prisma_message.role,
-            content=prisma_message.content,
-            name=prisma_message.name,
-            tool_call_id=prisma_message.toolCallId,
-            refusal=prisma_message.refusal,
-            tool_calls=_parse_json_field(prisma_message.toolCalls),
-            function_call=_parse_json_field(prisma_message.functionCall),
-        )
-
-
 class Usage(BaseModel):
     prompt_tokens: int
```
```diff
@@ -75,10 +92,11 @@ class Usage(BaseModel):
     total_tokens: int
 
 
-class ChatSessionInfo(BaseModel):
+class ChatSession(BaseModel):
     session_id: str
     user_id: str
     title: str | None = None
+    messages: list[ChatMessage]
     usage: list[Usage]
     credentials: dict[str, dict] = {}  # Map of provider -> credential metadata
     started_at: datetime
```
```diff
@@ -86,9 +104,40 @@ class ChatSessionInfo(BaseModel):
     successful_agent_runs: dict[str, int] = {}
     successful_agent_schedules: dict[str, int] = {}
 
-    @classmethod
-    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
-        """Convert Prisma ChatSession to Pydantic ChatSession."""
+    @staticmethod
+    def new(user_id: str) -> "ChatSession":
+        return ChatSession(
+            session_id=str(uuid.uuid4()),
+            user_id=user_id,
+            title=None,
+            messages=[],
+            usage=[],
+            credentials={},
+            started_at=datetime.now(UTC),
+            updated_at=datetime.now(UTC),
+        )
+
+    @staticmethod
+    def from_db(
+        prisma_session: PrismaChatSession,
+        prisma_messages: list[PrismaChatMessage] | None = None,
+    ) -> "ChatSession":
+        """Convert Prisma models to Pydantic ChatSession."""
+        messages = []
+        if prisma_messages:
+            for msg in prisma_messages:
+                messages.append(
+                    ChatMessage(
+                        role=msg.role,
+                        content=msg.content,
+                        name=msg.name,
+                        tool_call_id=msg.toolCallId,
+                        refusal=msg.refusal,
+                        tool_calls=_parse_json_field(msg.toolCalls),
+                        function_call=_parse_json_field(msg.functionCall),
+                    )
+                )
+
         # Parse JSON fields from Prisma
         credentials = _parse_json_field(prisma_session.credentials, default={})
         successful_agent_runs = _parse_json_field(
```
```diff
@@ -110,10 +159,11 @@ class ChatSessionInfo(BaseModel):
             )
         )
 
-        return cls(
+        return ChatSession(
             session_id=prisma_session.id,
             user_id=prisma_session.userId,
             title=prisma_session.title,
+            messages=messages,
             usage=usage,
             credentials=credentials,
             started_at=prisma_session.createdAt,
```
```diff
@@ -122,56 +172,6 @@ class ChatSessionInfo(BaseModel):
             successful_agent_schedules=successful_agent_schedules,
         )
 
-
-class ChatSession(ChatSessionInfo):
-    messages: list[ChatMessage]
-
-    @classmethod
-    def new(cls, user_id: str) -> Self:
-        return cls(
-            session_id=str(uuid.uuid4()),
-            user_id=user_id,
-            title=None,
-            messages=[],
-            usage=[],
-            credentials={},
-            started_at=datetime.now(UTC),
-            updated_at=datetime.now(UTC),
-        )
-
-    @classmethod
-    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
-        """Convert Prisma ChatSession to Pydantic ChatSession."""
-        if prisma_session.Messages is None:
-            raise ValueError(
-                f"Prisma session {prisma_session.id} is missing Messages relation"
-            )
-
-        return cls(
-            **ChatSessionInfo.from_db(prisma_session).model_dump(),
-            messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
-        )
-
-    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
-        """Attach a tool_call to the current turn's assistant message.
-
-        Searches backwards for the most recent assistant message (stopping at
-        any user message boundary). If found, appends the tool_call to it.
-        Otherwise creates a new assistant message with the tool_call.
-        """
-        for msg in reversed(self.messages):
-            if msg.role == "user":
-                break
-            if msg.role == "assistant":
-                if not msg.tool_calls:
-                    msg.tool_calls = []
-                msg.tool_calls.append(tool_call)
-                return
-
-        self.messages.append(
-            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
-        )
-
     def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
         messages = []
         for message in self.messages:
```
```diff
@@ -258,72 +258,43 @@ class ChatSession(ChatSessionInfo):
                     name=message.name or "",
                 )
             )
-        return self._merge_consecutive_assistant_messages(messages)
-
-    @staticmethod
-    def _merge_consecutive_assistant_messages(
-        messages: list[ChatCompletionMessageParam],
-    ) -> list[ChatCompletionMessageParam]:
-        """Merge consecutive assistant messages into single messages.
-
-        Long-running tool flows can create split assistant messages: one with
-        text content and another with tool_calls. Anthropic's API requires
-        tool_result blocks to reference a tool_use in the immediately preceding
-        assistant message, so these splits cause 400 errors via OpenRouter.
-        """
-        if len(messages) < 2:
-            return messages
-
-        result: list[ChatCompletionMessageParam] = [messages[0]]
-        for msg in messages[1:]:
-            prev = result[-1]
-            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
-                result.append(msg)
-                continue
-
-            prev = cast(ChatCompletionAssistantMessageParam, prev)
-            curr = cast(ChatCompletionAssistantMessageParam, msg)
-
-            curr_content = curr.get("content") or ""
-            if curr_content:
-                prev_content = prev.get("content") or ""
-                prev["content"] = (
-                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
-                )
-
-            curr_tool_calls = curr.get("tool_calls")
-            if curr_tool_calls:
-                prev_tool_calls = prev.get("tool_calls")
-                prev["tool_calls"] = (
-                    list(prev_tool_calls) + list(curr_tool_calls)
-                    if prev_tool_calls
-                    else list(curr_tool_calls)
-                )
-        return result
-
-
-def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
-    """Parse a JSON field that may be stored as string or already parsed."""
-    if value is None:
-        return default
-    if isinstance(value, str):
-        return json.loads(value)
-    return value
+        return messages
 
 
-# ================ Chat cache + DB operations ================ #
-
-# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
-# connected directly.
-
-
-async def cache_chat_session(session: ChatSession) -> None:
-    """Cache a chat session in Redis (without persisting to the database)."""
+async def _get_session_from_cache(session_id: str) -> ChatSession | None:
+    """Get a chat session from Redis cache."""
+    redis_key = _get_session_cache_key(session_id)
+    async_redis = await get_redis_async()
+    raw_session: bytes | None = await async_redis.get(redis_key)
+
+    if raw_session is None:
+        return None
+
+    try:
+        session = ChatSession.model_validate_json(raw_session)
+        logger.info(
+            f"Loading session {session_id} from cache: "
+            f"message_count={len(session.messages)}, "
+            f"roles={[m.role for m in session.messages]}"
+        )
+        return session
+    except Exception as e:
+        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
+        raise RedisError(f"Corrupted session data for {session_id}") from e
+
+
+async def _cache_session(session: ChatSession) -> None:
+    """Cache a chat session in Redis."""
     redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
+
+
+async def cache_chat_session(session: ChatSession) -> None:
+    """Cache a chat session without persisting to the database."""
+    await _cache_session(session)
 
 
 async def invalidate_session_cache(session_id: str) -> None:
     """Invalidate a chat session from Redis cache.
 
```
@@ -339,6 +310,80 @@ async def invalidate_session_cache(session_id: str) -> None:
         logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
 
 
+async def _get_session_from_db(session_id: str) -> ChatSession | None:
+    """Get a chat session from the database."""
+    prisma_session = await chat_db.get_chat_session(session_id)
+    if not prisma_session:
+        return None
+
+    messages = prisma_session.Messages
+    logger.info(
+        f"Loading session {session_id} from DB: "
+        f"has_messages={messages is not None}, "
+        f"message_count={len(messages) if messages else 0}, "
+        f"roles={[m.role for m in messages] if messages else []}"
+    )
+
+    return ChatSession.from_db(prisma_session, messages)
+
+
+async def _save_session_to_db(
+    session: ChatSession, existing_message_count: int
+) -> None:
+    """Save or update a chat session in the database."""
+    # Check if session exists in DB
+    existing = await chat_db.get_chat_session(session.session_id)
+
+    if not existing:
+        # Create new session
+        await chat_db.create_chat_session(
+            session_id=session.session_id,
+            user_id=session.user_id,
+        )
+        existing_message_count = 0
+
+    # Calculate total tokens from usage
+    total_prompt = sum(u.prompt_tokens for u in session.usage)
+    total_completion = sum(u.completion_tokens for u in session.usage)
+
+    # Update session metadata
+    await chat_db.update_chat_session(
+        session_id=session.session_id,
+        credentials=session.credentials,
+        successful_agent_runs=session.successful_agent_runs,
+        successful_agent_schedules=session.successful_agent_schedules,
+        total_prompt_tokens=total_prompt,
+        total_completion_tokens=total_completion,
+    )
+
+    # Add new messages (only those after existing count)
+    new_messages = session.messages[existing_message_count:]
+    if new_messages:
+        messages_data = []
+        for msg in new_messages:
+            messages_data.append(
+                {
+                    "role": msg.role,
+                    "content": msg.content,
+                    "name": msg.name,
+                    "tool_call_id": msg.tool_call_id,
+                    "refusal": msg.refusal,
+                    "tool_calls": msg.tool_calls,
+                    "function_call": msg.function_call,
+                }
+            )
+        logger.info(
+            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
+            f"roles={[m['role'] for m in messages_data]}, "
+            f"start_sequence={existing_message_count}"
+        )
+        await chat_db.add_chat_messages_batch(
+            session_id=session.session_id,
+            messages=messages_data,
+            start_sequence=existing_message_count,
+        )
+
+
 async def get_chat_session(
     session_id: str,
     user_id: str | None = None,
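The incremental-save logic above writes only `session.messages[existing_message_count:]` and assigns each new row a sequence number starting at the existing count, so re-saving a session never duplicates already-persisted messages. A tiny sketch of that slice-and-number step (the function name is illustrative):

```python
def new_messages_to_write(messages: list[str], existing_count: int) -> list[tuple[int, str]]:
    """Return (sequence, message) pairs for messages not yet persisted,
    mirroring the messages[existing_message_count:] slice above."""
    return [(existing_count + i, m) for i, m in enumerate(messages[existing_count:])]


rows = new_messages_to_write(["m0", "m1", "m2", "m3"], existing_count=2)
# Only m2 and m3 are written, with sequence numbers 2 and 3.
```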
@@ -370,7 +415,7 @@ async def get_chat_session(
         logger.warning(f"Unexpected cache error for session {session_id}: {e}")
 
     # Fall back to database
-    logger.debug(f"Session {session_id} not in cache, checking database")
+    logger.info(f"Session {session_id} not in cache, checking database")
     session = await _get_session_from_db(session_id)
 
     if session is None:
@@ -386,7 +431,7 @@ async def get_chat_session(
 
     # Cache the session from DB
     try:
-        await cache_chat_session(session)
+        await _cache_session(session)
         logger.info(f"Cached session {session_id} from database")
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -394,44 +439,6 @@ async def get_chat_session(
     return session
 
 
-async def _get_session_from_cache(session_id: str) -> ChatSession | None:
-    """Get a chat session from Redis cache."""
-    redis_key = _get_session_cache_key(session_id)
-    async_redis = await get_redis_async()
-    raw_session: bytes | None = await async_redis.get(redis_key)
-
-    if raw_session is None:
-        return None
-
-    try:
-        session = ChatSession.model_validate_json(raw_session)
-        logger.info(
-            f"Loading session {session_id} from cache: "
-            f"message_count={len(session.messages)}, "
-            f"roles={[m.role for m in session.messages]}"
-        )
-        return session
-    except Exception as e:
-        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
-        raise RedisError(f"Corrupted session data for {session_id}") from e
-
-
-async def _get_session_from_db(session_id: str) -> ChatSession | None:
-    """Get a chat session from the database."""
-    session = await chat_db().get_chat_session(session_id)
-    if not session:
-        return None
-
-    logger.info(
-        f"Loaded session {session_id} from DB: "
-        f"has_messages={bool(session.messages)}, "
-        f"message_count={len(session.messages)}, "
-        f"roles={[m.role for m in session.messages]}"
-    )
-
-    return session
-
-
 async def upsert_chat_session(
     session: ChatSession,
 ) -> ChatSession:
@@ -451,35 +458,25 @@ async def upsert_chat_session(
     lock = await _get_session_lock(session.session_id)
 
     async with lock:
-        # Always query DB for existing message count to ensure consistency
-        existing_message_count = await chat_db().get_next_sequence(session.session_id)
+        # Get existing message count from DB for incremental saves
+        existing_message_count = await chat_db.get_chat_session_message_count(
+            session.session_id
+        )
 
         db_error: Exception | None = None
 
         # Save to database (primary storage)
         try:
-            await _save_session_to_db(
-                session,
-                existing_message_count,
-                skip_existence_check=existing_message_count > 0,
-            )
+            await _save_session_to_db(session, existing_message_count)
         except Exception as e:
             logger.error(
                 f"Failed to save session {session.session_id} to database: {e}"
             )
             db_error = e
 
-        # Save to cache (best-effort, even if DB failed).
-        # Title updates (update_session_title) run *outside* this lock because
-        # they only touch the title field, not messages. So a concurrent rename
-        # or auto-title may have written a newer title to Redis while this
-        # upsert was in progress. Always prefer the cached title to avoid
-        # overwriting it with the stale in-memory copy.
+        # Save to cache (best-effort, even if DB failed)
         try:
-            existing_cached = await _get_session_from_cache(session.session_id)
-            if existing_cached and existing_cached.title:
-                session = session.model_copy(update={"title": existing_cached.title})
-            await cache_chat_session(session)
+            await _cache_session(session)
         except Exception as e:
             # If DB succeeded but cache failed, raise cache error
             if db_error is None:
@@ -500,107 +497,6 @@ async def upsert_chat_session(
     return session
 
 
-async def _save_session_to_db(
-    session: ChatSession,
-    existing_message_count: int,
-    *,
-    skip_existence_check: bool = False,
-) -> None:
-    """Save or update a chat session in the database.
-
-    Args:
-        skip_existence_check: When True, skip the ``get_chat_session`` query
-            and assume the session row already exists. Saves one DB round trip
-            for incremental saves during streaming.
-    """
-    db = chat_db()
-
-    if not skip_existence_check:
-        # Check if session exists in DB
-        existing = await db.get_chat_session(session.session_id)
-
-        if not existing:
-            # Create new session
-            await db.create_chat_session(
-                session_id=session.session_id,
-                user_id=session.user_id,
-            )
-            existing_message_count = 0
-
-    # Calculate total tokens from usage
-    total_prompt = sum(u.prompt_tokens for u in session.usage)
-    total_completion = sum(u.completion_tokens for u in session.usage)
-
-    # Update session metadata
-    await db.update_chat_session(
-        session_id=session.session_id,
-        credentials=session.credentials,
-        successful_agent_runs=session.successful_agent_runs,
-        successful_agent_schedules=session.successful_agent_schedules,
-        total_prompt_tokens=total_prompt,
-        total_completion_tokens=total_completion,
-    )
-
-    # Add new messages (only those after existing count)
-    new_messages = session.messages[existing_message_count:]
-    if new_messages:
-        messages_data = []
-        for msg in new_messages:
-            messages_data.append(
-                {
-                    "role": msg.role,
-                    "content": msg.content,
-                    "name": msg.name,
-                    "tool_call_id": msg.tool_call_id,
-                    "refusal": msg.refusal,
-                    "tool_calls": msg.tool_calls,
-                    "function_call": msg.function_call,
-                }
-            )
-        logger.info(
-            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
-            f"roles={[m['role'] for m in messages_data]}, "
-            f"start_sequence={existing_message_count}"
-        )
-        await db.add_chat_messages_batch(
-            session_id=session.session_id,
-            messages=messages_data,
-            start_sequence=existing_message_count,
-        )
-
-
-async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
-    """Atomically append a message to a session and persist it.
-
-    Acquires the session lock, re-fetches the latest session state,
-    appends the message, and saves — preventing message loss when
-    concurrent requests modify the same session.
-    """
-    lock = await _get_session_lock(session_id)
-
-    async with lock:
-        session = await get_chat_session(session_id)
-        if session is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        session.messages.append(message)
-        existing_message_count = await chat_db().get_next_sequence(session_id)
-
-        try:
-            await _save_session_to_db(session, existing_message_count)
-        except Exception as e:
-            raise DatabaseError(
-                f"Failed to persist message to session {session_id}"
-            ) from e
-
-        try:
-            await cache_chat_session(session)
-        except Exception as e:
-            logger.warning(f"Cache write failed for session {session_id}: {e}")
-
-        return session
-
-
 async def create_chat_session(user_id: str) -> ChatSession:
     """Create a new chat session and persist it.
@@ -613,7 +509,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
 
     # Create in database first - fail fast if this fails
     try:
-        await chat_db().create_chat_session(
+        await chat_db.create_chat_session(
             session_id=session.session_id,
             user_id=user_id,
         )
@@ -625,7 +521,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
 
     # Cache the session (best-effort optimization, DB is source of truth)
     try:
-        await cache_chat_session(session)
+        await _cache_session(session)
     except Exception as e:
         logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -636,16 +532,20 @@ async def get_user_sessions(
     user_id: str,
     limit: int = 50,
     offset: int = 0,
-) -> tuple[list[ChatSessionInfo], int]:
+) -> tuple[list[ChatSession], int]:
     """Get chat sessions for a user from the database with total count.
 
     Returns:
         A tuple of (sessions, total_count) where total_count is the overall
         number of sessions for the user (not just the current page).
     """
-    db = chat_db()
-    sessions = await db.get_user_chat_sessions(user_id, limit, offset)
-    total_count = await db.get_user_session_count(user_id)
+    prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
+    total_count = await chat_db.get_user_session_count(user_id)
+
+    sessions = []
+    for prisma_session in prisma_sessions:
+        # Convert without messages for listing (lighter weight)
+        sessions.append(ChatSession.from_db(prisma_session, None))
 
     return sessions, total_count
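`get_user_sessions` above returns a `(sessions, total_count)` tuple so the caller can render one page while still showing the overall count. A minimal sketch of that pagination contract over an in-memory list (the helper name is illustrative):

```python
def paginate(items: list, limit: int, offset: int) -> tuple[list, int]:
    """Return one page plus the overall count, matching the
    (sessions, total_count) contract of get_user_sessions above."""
    return items[offset : offset + limit], len(items)


page, total = paginate(list(range(10)), limit=3, offset=6)
# page is [6, 7, 8]; total is 10 regardless of the page requested.
```

Returning the total alongside the page is what lets a UI compute "page N of M" without a second query.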
@@ -663,7 +563,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
     """
     # Delete from database first (with optional user_id validation)
     # This confirms ownership before invalidating cache
-    deleted = await chat_db().delete_chat_session(session_id, user_id)
+    deleted = await chat_db.delete_chat_session(session_id, user_id)
 
     if not deleted:
         return False
@@ -680,89 +580,38 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
     async with _session_locks_mutex:
         _session_locks.pop(session_id, None)
 
-    # Shut down any local browser daemon for this session (best-effort).
-    # Inline import required: all tool modules import ChatSession from this
-    # module, so any top-level import from tools.* would create a cycle.
-    try:
-        from .tools.agent_browser import close_browser_session
-
-        await close_browser_session(session_id, user_id=user_id)
-    except Exception as e:
-        logger.debug(f"Browser cleanup for session {session_id}: {e}")
-
     return True
 
 
-async def update_session_title(
-    session_id: str,
-    user_id: str,
-    title: str,
-    *,
-    only_if_empty: bool = False,
-) -> bool:
-    """Update the title of a chat session, scoped to the owning user.
-
-    Lightweight operation that doesn't touch messages, avoiding race conditions
-    with concurrent message updates.
-
-    Args:
-        session_id: The session ID to update.
-        user_id: Owning user — the DB query filters on this.
-        title: The new title to set.
-        only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
-            so auto-generated titles never overwrite a user-set title.
-
-    Returns:
-        True if updated successfully, False otherwise (not found, wrong user,
-        or — when only_if_empty — title was already set).
-    """
-    try:
-        updated = await chat_db().update_chat_session_title(
-            session_id, user_id, title, only_if_empty=only_if_empty
-        )
-        if not updated:
-            return False
-
-        # Update title in cache if it exists (instead of invalidating).
-        # This prevents race conditions where cache invalidation causes
-        # the frontend to see stale DB data while streaming is still in progress.
-        try:
-            cached = await _get_session_from_cache(session_id)
-            if cached:
-                cached.title = title
-                await cache_chat_session(cached)
-        except Exception as e:
-            logger.warning(
-                f"Cache title update failed for session {session_id} (non-critical): {e}"
-            )
-
-        return True
+async def update_session_title(session_id: str, title: str) -> bool:
+    """Update only the title of a chat session.
+
+    This is a lightweight operation that doesn't touch messages, avoiding
+    race conditions with concurrent message updates. Use this for background
+    title generation instead of upsert_chat_session.
+
+    Args:
+        session_id: The session ID to update.
+        title: The new title to set.
+
+    Returns:
+        True if updated successfully, False otherwise.
+    """
+    try:
+        result = await chat_db.update_chat_session(session_id=session_id, title=title)
+        if result is None:
+            logger.warning(f"Session {session_id} not found for title update")
+            return False
+
+        # Invalidate cache so next fetch gets updated title
+        try:
+            redis_key = _get_session_cache_key(session_id)
+            async_redis = await get_redis_async()
+            await async_redis.delete(redis_key)
+        except Exception as e:
+            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
+
+        return True
     except Exception as e:
         logger.error(f"Failed to update title for session {session_id}: {e}")
         return False
 
 
-# ==================== Chat session locks ==================== #
-
-_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
-_session_locks_mutex = asyncio.Lock()
-
-
-async def _get_session_lock(session_id: str) -> asyncio.Lock:
-    """Get or create a lock for a specific session to prevent concurrent upserts.
-
-    This was originally added to solve the specific problem of race conditions between
-    the session title thread and the conversation thread, which always occurs on the
-    same instance as we prevent rapid request sends on the frontend.
-
-    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
-    when no coroutine holds a reference to them, preventing memory leaks from
-    unbounded growth of session locks. Explicit cleanup also occurs
-    in `delete_chat_session()`.
-    """
-    async with _session_locks_mutex:
-        lock = _session_locks.get(session_id)
-        if lock is None:
-            lock = asyncio.Lock()
-            _session_locks[session_id] = lock
-        return lock
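The removed `_get_session_lock` helper above implements a get-or-create per-session `asyncio.Lock` stored in a `WeakValueDictionary`, so locks are garbage collected once no coroutine references them. A standalone sketch of the same pattern (module-level names here are illustrative, not the project's):

```python
import asyncio
from weakref import WeakValueDictionary

# Locks disappear automatically once no coroutine holds a reference to them.
_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_locks_mutex = asyncio.Lock()


async def get_session_lock(session_id: str) -> asyncio.Lock:
    # The mutex makes get-or-create atomic, so two coroutines asking for
    # the same session id always receive the same Lock instance.
    async with _locks_mutex:
        lock = _locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _locks[session_id] = lock
        return lock


async def main() -> tuple[bool, bool]:
    a = await get_session_lock("s1")
    b = await get_session_lock("s1")
    c = await get_session_lock("s2")
    return a is b, a is c


same, different = asyncio.run(main())
# same is True (one lock per session id), different is False.
```

Without the weak references, a long-lived process would accumulate one lock per session id ever seen; with them, only sessions with in-flight work keep a lock alive.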
autogpt_platform/backend/backend/api/features/chat/model_test.py (new file, 119 lines)
@@ -0,0 +1,119 @@
+import pytest
+
+from .model import (
+    ChatMessage,
+    ChatSession,
+    Usage,
+    get_chat_session,
+    upsert_chat_session,
+)
+
+messages = [
+    ChatMessage(content="Hello, how are you?", role="user"),
+    ChatMessage(
+        content="I'm fine, thank you!",
+        role="assistant",
+        tool_calls=[
+            {
+                "id": "t123",
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "arguments": '{"city": "New York"}',
+                },
+            }
+        ],
+    ),
+    ChatMessage(
+        content="I'm using the tool to get the weather",
+        role="tool",
+        tool_call_id="t123",
+    ),
+]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_chatsession_serialization_deserialization():
+    s = ChatSession.new(user_id="abc123")
+    s.messages = messages
+    s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
+    serialized = s.model_dump_json()
+    s2 = ChatSession.model_validate_json(serialized)
+    assert s2.model_dump() == s.model_dump()
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_chatsession_redis_storage(setup_test_user, test_user_id):
+    s = ChatSession.new(user_id=test_user_id)
+    s.messages = messages
+
+    s = await upsert_chat_session(s)
+
+    s2 = await get_chat_session(
+        session_id=s.session_id,
+        user_id=s.user_id,
+    )
+
+    assert s2 == s
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_chatsession_redis_storage_user_id_mismatch(
+    setup_test_user, test_user_id
+):
+    s = ChatSession.new(user_id=test_user_id)
+    s.messages = messages
+    s = await upsert_chat_session(s)
+
+    s2 = await get_chat_session(s.session_id, "different_user_id")
+
+    assert s2 is None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_chatsession_db_storage(setup_test_user, test_user_id):
+    """Test that messages are correctly saved to and loaded from DB (not cache)."""
+    from backend.data.redis_client import get_redis_async
+
+    # Create session with messages including assistant message
+    s = ChatSession.new(user_id=test_user_id)
+    s.messages = messages  # Contains user, assistant, and tool messages
+    assert s.session_id is not None, "Session id is not set"
+    # Upsert to save to both cache and DB
+    s = await upsert_chat_session(s)
+
+    # Clear the Redis cache to force DB load
+    redis_key = f"chat:session:{s.session_id}"
+    async_redis = await get_redis_async()
+    await async_redis.delete(redis_key)
+
+    # Load from DB (cache was cleared)
+    s2 = await get_chat_session(
+        session_id=s.session_id,
+        user_id=s.user_id,
+    )
+
+    assert s2 is not None, "Session not found after loading from DB"
+    assert len(s2.messages) == len(
+        s.messages
+    ), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
+
+    # Verify all roles are present
+    roles = [m.role for m in s2.messages]
+    assert "user" in roles, f"User message missing. Roles found: {roles}"
+    assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
+    assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
+
+    # Verify message content
+    for orig, loaded in zip(s.messages, s2.messages):
+        assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
+        assert (
+            orig.content == loaded.content
+        ), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
+        if orig.tool_calls:
+            assert (
+                loaded.tool_calls is not None
+            ), f"Tool calls missing for {orig.role} message"
+            assert len(orig.tool_calls) == len(loaded.tool_calls)
@@ -5,18 +5,11 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
 See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
 """
 
-import json
-import logging
 from enum import Enum
 from typing import Any
 
 from pydantic import BaseModel, Field
 
-from backend.util.json import dumps as json_dumps
-from backend.util.truncate import truncate
-
-logger = logging.getLogger(__name__)
-
 
 class ResponseType(str, Enum):
     """Types of streaming responses following AI SDK protocol."""
@@ -25,10 +18,6 @@ class ResponseType(str, Enum):
     START = "start"
     FINISH = "finish"
 
-    # Step lifecycle (one LLM API call within a message)
-    START_STEP = "start-step"
-    FINISH_STEP = "finish-step"
-
     # Text streaming
     TEXT_START = "text-start"
     TEXT_DELTA = "text-delta"
@@ -52,8 +41,7 @@ class StreamBaseResponse(BaseModel):
 
     def to_sse(self) -> str:
         """Convert to SSE format."""
-        json_str = self.model_dump_json(exclude_none=True)
-        return f"data: {json_str}\n\n"
+        return f"data: {self.model_dump_json()}\n\n"
 
 
 # ========== Message Lifecycle ==========
@@ -64,18 +52,6 @@ class StreamStart(StreamBaseResponse):
 
     type: ResponseType = ResponseType.START
     messageId: str = Field(..., description="Unique message ID")
-    sessionId: str | None = Field(
-        default=None,
-        description="Session ID for SSE reconnection.",
-    )
-
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-protocol fields like sessionId."""
-        data: dict[str, Any] = {
-            "type": self.type.value,
-            "messageId": self.messageId,
-        }
-        return f"data: {json.dumps(data)}\n\n"
 
 
 class StreamFinish(StreamBaseResponse):
@@ -84,26 +60,6 @@ class StreamFinish(StreamBaseResponse):
     type: ResponseType = ResponseType.FINISH
 
 
-class StreamStartStep(StreamBaseResponse):
-    """Start of a step (one LLM API call within a message).
-
-    The AI SDK uses this to add a step-start boundary to message.parts,
-    enabling visual separation between multiple LLM calls in a single message.
-    """
-
-    type: ResponseType = ResponseType.START_STEP
-
-
-class StreamFinishStep(StreamBaseResponse):
-    """End of a step (one LLM API call within a message).
-
-    The AI SDK uses this to reset activeTextParts and activeReasoningParts,
-    so the next LLM call in a tool-call continuation starts with clean state.
-    """
-
-    type: ResponseType = ResponseType.FINISH_STEP
-
-
 # ========== Text Streaming ==========
 
 
@@ -151,16 +107,13 @@ class StreamToolInputAvailable(StreamBaseResponse):
     )
 
 
-_MAX_TOOL_OUTPUT_SIZE = 100_000  # ~100 KB; truncate to avoid bloating SSE/DB
-
-
 class StreamToolOutputAvailable(StreamBaseResponse):
     """Tool execution result."""
 
     type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
     toolCallId: str = Field(..., description="Tool call ID this responds to")
     output: str | dict[str, Any] = Field(..., description="Tool execution output")
-    # Keep these for internal backend use
+    # Additional fields for internal use (not part of AI SDK spec but useful)
     toolName: str | None = Field(
         default=None, description="Name of the tool that was executed"
     )
@@ -168,19 +121,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
         default=True, description="Whether the tool execution succeeded"
     )
 
-    def model_post_init(self, __context: Any) -> None:
-        """Truncate oversized outputs after construction."""
-        self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
-
-    def to_sse(self) -> str:
-        """Convert to SSE format, excluding non-spec fields."""
-        data = {
-            "type": self.type.value,
-            "toolCallId": self.toolCallId,
-            "output": self.output,
-        }
-        return f"data: {json.dumps(data)}\n\n"
-
 
 # ========== Other ==========
 
@@ -204,18 +144,6 @@ class StreamError(StreamBaseResponse):
         default=None, description="Additional error details"
     )
 
-    def to_sse(self) -> str:
-        """Convert to SSE format, only emitting fields required by AI SDK protocol.
-
-        The AI SDK uses z.strictObject({type, errorText}) which rejects
-        any extra fields like `code` or `details`.
-        """
-        data = {
-            "type": self.type.value,
-            "errorText": self.errorText,
-        }
-        return f"data: {json_dumps(data)}\n\n"
-
 
 class StreamHeartbeat(StreamBaseResponse):
     """Heartbeat to keep SSE connection alive during long-running operations.
@@ -1,68 +1,22 @@
 """Chat API routes for chat session management and streaming via SSE."""
 
-import asyncio
 import logging
-import re
 from collections.abc import AsyncGenerator
 from typing import Annotated
-from uuid import uuid4
 
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
+from fastapi import APIRouter, Depends, Query, Security
 from fastapi.responses import StreamingResponse
-from prisma.models import UserWorkspaceFile
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel
 
-from backend.copilot import service as chat_service
-from backend.copilot import stream_registry
-from backend.copilot.config import ChatConfig
-from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
-from backend.copilot.model import (
-    ChatMessage,
-    ChatSession,
-    append_and_save_message,
-    create_chat_session,
-    delete_chat_session,
-    get_chat_session,
-    get_user_sessions,
-    update_session_title,
-)
-from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
-from backend.copilot.tools.e2b_sandbox import kill_sandbox
-from backend.copilot.tools.models import (
-    AgentDetailsResponse,
-    AgentOutputResponse,
-    AgentPreviewResponse,
-    AgentSavedResponse,
-    AgentsFoundResponse,
-    BlockDetailsResponse,
-    BlockListResponse,
-    BlockOutputResponse,
-    ClarificationNeededResponse,
-    DocPageResponse,
-    DocSearchResultsResponse,
-    ErrorResponse,
-    ExecutionStartedResponse,
-    InputValidationErrorResponse,
-    MCPToolOutputResponse,
-    MCPToolsDiscoveredResponse,
-    NeedLoginResponse,
-    NoResultsResponse,
-    SetupRequirementsResponse,
-    SuggestedGoalResponse,
-    UnderstandingUpdatedResponse,
-)
-from backend.copilot.tracking import track_user_message
-from backend.data.redis_client import get_redis_async
-from backend.data.understanding import get_business_understanding
-from backend.data.workspace import get_or_create_workspace
 from backend.util.exceptions import NotFoundError
 
+from . import service as chat_service
+from .config import ChatConfig
+from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
 
 config = ChatConfig()
 
-_UUID_RE = re.compile(
-    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
-)
-
 logger = logging.getLogger(__name__)
 
@@ -91,9 +45,6 @@ class StreamChatRequest(BaseModel):
     message: str
     is_user_message: bool = True
     context: dict[str, str] | None = None  # {url: str, content: str}
-    file_ids: list[str] | None = Field(
-        default=None, max_length=20
-    )  # Workspace file IDs attached to this message
 
 
 class CreateSessionResponse(BaseModel):
@@ -104,13 +55,6 @@ class CreateSessionResponse(BaseModel):
     user_id: str | None
 
 
-class ActiveStreamInfo(BaseModel):
-    """Information about an active stream for reconnection."""
-
-    turn_id: str
-    last_message_id: str  # Redis Stream message ID for resumption
-
-
 class SessionDetailResponse(BaseModel):
     """Response model providing complete details for a chat session, including messages."""
 
@@ -119,7 +63,6 @@ class SessionDetailResponse(BaseModel):
     updated_at: str
     user_id: str | None
     messages: list[dict]
-    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
 
 
 class SessionSummaryResponse(BaseModel):
@@ -129,7 +72,6 @@ class SessionSummaryResponse(BaseModel):
     created_at: str
     updated_at: str
     title: str | None = None
-    is_processing: bool
 
 
 class ListSessionsResponse(BaseModel):
@@ -139,27 +81,6 @@ class ListSessionsResponse(BaseModel):
     total: int
 
 
-class CancelSessionResponse(BaseModel):
-    """Response model for the cancel session endpoint."""
-
-    cancelled: bool
-    reason: str | None = None
-
-
-class UpdateSessionTitleRequest(BaseModel):
-    """Request model for updating a session's title."""
-
-    title: str
-
-    @field_validator("title")
-    @classmethod
-    def title_must_not_be_blank(cls, v: str) -> str:
-        stripped = v.strip()
-        if not stripped:
-            raise ValueError("Title must not be blank")
-        return stripped
-
-
 # ========== Routes ==========
 
 
@@ -188,28 +109,6 @@ async def list_sessions(
     """
     sessions, total_count = await get_user_sessions(user_id, limit, offset)
 
-    # Batch-check Redis for active stream status on each session
-    processing_set: set[str] = set()
-    if sessions:
-        try:
-            redis = await get_redis_async()
-            pipe = redis.pipeline(transaction=False)
-            for session in sessions:
-                pipe.hget(
-                    f"{config.session_meta_prefix}{session.session_id}",
-                    "status",
-                )
-            statuses = await pipe.execute()
-            processing_set = {
-                session.session_id
-                for session, st in zip(sessions, statuses)
-                if st == "running"
-            }
-        except Exception:
-            logger.warning(
-                "Failed to fetch processing status from Redis; " "defaulting to empty"
-            )
-
     return ListSessionsResponse(
         sessions=[
             SessionSummaryResponse(
@@ -217,7 +116,6 @@ async def list_sessions(
                 created_at=session.started_at.isoformat(),
                 updated_at=session.updated_at.isoformat(),
                 title=session.title,
-                is_processing=session.session_id in processing_set,
             )
             for session in sessions
         ],
@@ -257,92 +155,6 @@ async def create_session(
     )
 
 
-@router.delete(
-    "/sessions/{session_id}",
-    dependencies=[Security(auth.requires_user)],
-    status_code=204,
-    responses={404: {"description": "Session not found or access denied"}},
-)
-async def delete_session(
-    session_id: str,
-    user_id: Annotated[str, Security(auth.get_user_id)],
-) -> Response:
-    """
-    Delete a chat session.
-
-    Permanently removes a chat session and all its messages.
-    Only the owner can delete their sessions.
-
-    Args:
-        session_id: The session ID to delete.
-        user_id: The authenticated user's ID.
-
-    Returns:
-        204 No Content on success.
-
-    Raises:
-        HTTPException: 404 if session not found or not owned by user.
-    """
-    deleted = await delete_chat_session(session_id, user_id)
-
-    if not deleted:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Session {session_id} not found or access denied",
-        )
-
-    # Best-effort cleanup of the E2B sandbox (if any).
-    # sandbox_id is in Redis; kill_sandbox() fetches it from there.
-    e2b_cfg = ChatConfig()
-    if e2b_cfg.e2b_active:
-        assert e2b_cfg.e2b_api_key  # guaranteed by e2b_active check
-        try:
-            await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
-        except Exception:
-            logger.warning(
-                "[E2B] Failed to kill sandbox for session %s", session_id[:12]
-            )
-
-    return Response(status_code=204)
-
-
-@router.patch(
-    "/sessions/{session_id}/title",
-    summary="Update session title",
-    dependencies=[Security(auth.requires_user)],
-    status_code=200,
-    responses={404: {"description": "Session not found or access denied"}},
-)
-async def update_session_title_route(
-    session_id: str,
-    request: UpdateSessionTitleRequest,
-    user_id: Annotated[str, Security(auth.get_user_id)],
-) -> dict:
-    """
-    Update the title of a chat session.
-
-    Allows the user to rename their chat session.
-
-    Args:
-        session_id: The session ID to update.
-        request: Request body containing the new title.
-        user_id: The authenticated user's ID.
-
-    Returns:
-        dict: Status of the update.
-
-    Raises:
-        HTTPException: 404 if session not found or not owned by user.
-    """
-    success = await update_session_title(session_id, user_id, request.title)
-    if not success:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Session {session_id} not found or access denied",
-        )
-    return {"status": "ok"}
-
-
 @router.get(
     "/sessions/{session_id}",
 )
@@ -354,14 +166,13 @@ async def get_session(
     Retrieve the details of a specific chat session.
 
     Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
-    If there's an active stream for this session, returns active_stream info for reconnection.
 
     Args:
         session_id: The unique identifier for the desired chat session.
         user_id: The optional authenticated user ID, or None for anonymous access.
 
     Returns:
-        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
+        SessionDetailResponse: Details for the requested session, or None if not found.
 
     """
     session = await get_chat_session(session_id, user_id)
@@ -369,25 +180,11 @@ async def get_session(
         raise NotFoundError(f"Session {session_id} not found.")
 
     messages = [message.model_dump() for message in session.messages]
 
-    # Check if there's an active stream for this session
-    active_stream_info = None
-    active_session, last_message_id = await stream_registry.get_active_session(
-        session_id, user_id
-    )
     logger.info(
-        f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
-        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
+        f"Returning session {session_id}: "
+        f"message_count={len(messages)}, "
+        f"roles={[m.get('role') for m in messages]}"
     )
-    if active_session:
-        # Keep the assistant message (including tool_calls) so the frontend can
-        # render the correct tool UI (e.g. CreateAgent with mini game).
-        # convertChatSessionToUiMessages handles isComplete=false by setting
-        # tool parts without output to state "input-available".
-        active_stream_info = ActiveStreamInfo(
-            turn_id=active_session.turn_id,
-            last_message_id=last_message_id,
-        )
 
     return SessionDetailResponse(
         id=session.session_id,
@@ -395,55 +192,9 @@ async def get_session(
         updated_at=session.updated_at.isoformat(),
         user_id=session.user_id or None,
         messages=messages,
-        active_stream=active_stream_info,
     )
 
 
-@router.post(
-    "/sessions/{session_id}/cancel",
-    status_code=200,
-)
-async def cancel_session_task(
-    session_id: str,
-    user_id: Annotated[str | None, Depends(auth.get_user_id)],
-) -> CancelSessionResponse:
-    """Cancel the active streaming task for a session.
-
-    Publishes a cancel event to the executor via RabbitMQ FANOUT, then
-    polls Redis until the task status flips from ``running`` or a timeout
-    (5 s) is reached. Returns only after the cancellation is confirmed.
-    """
-    await _validate_and_get_session(session_id, user_id)
-
-    active_session, _ = await stream_registry.get_active_session(session_id, user_id)
-    if not active_session:
-        return CancelSessionResponse(cancelled=True, reason="no_active_session")
-
-    await enqueue_cancel_task(session_id)
-    logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
-
-    # Poll until the executor confirms the task is no longer running.
-    poll_interval = 0.5
-    max_wait = 5.0
-    waited = 0.0
-    while waited < max_wait:
-        await asyncio.sleep(poll_interval)
-        waited += poll_interval
-        session_state = await stream_registry.get_session(session_id)
-        if session_state is None or session_state.status != "running":
-            logger.info(
-                f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
-                f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
-            )
-            return CancelSessionResponse(cancelled=True)
-
-    logger.warning(
-        f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
-    )
-    await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
-    return CancelSessionResponse(cancelled=True)
-
-
 @router.post(
     "/sessions/{session_id}/stream",
 )
@@ -460,10 +211,6 @@ async def stream_chat_post(
     - Tool call UI elements (if invoked)
     - Tool execution results
 
-    The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to a per-turn Redis stream for reconnection support. If the client
-    disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
-
     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
@@ -472,246 +219,41 @@ async def stream_chat_post(
         StreamingResponse: SSE-formatted response chunks.
 
     """
-    import asyncio
-    import time
-
-    stream_start_time = time.perf_counter()
-    log_meta = {"component": "ChatStream", "session_id": session_id}
-    if user_id:
-        log_meta["user_id"] = user_id
-
-    logger.info(
-        f"[TIMING] stream_chat_post STARTED, session={session_id}, "
-        f"user={user_id}, message_len={len(request.message)}",
-        extra={"json_fields": log_meta},
-    )
-    await _validate_and_get_session(session_id, user_id)
-    logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - stream_start_time) * 1000,
-            }
-        },
-    )
-
-    # Enrich message with file metadata if file_ids are provided.
-    # Also sanitise file_ids so only validated, workspace-scoped IDs are
-    # forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
-    sanitized_file_ids: list[str] | None = None
-    if request.file_ids and user_id:
-        # Filter to valid UUIDs only to prevent DB abuse
-        valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
-
-        if valid_ids:
-            workspace = await get_or_create_workspace(user_id)
-            # Batch query instead of N+1
-            files = await UserWorkspaceFile.prisma().find_many(
-                where={
-                    "id": {"in": valid_ids},
-                    "workspaceId": workspace.id,
-                    "isDeleted": False,
-                }
-            )
-            # Only keep IDs that actually exist in the user's workspace
-            sanitized_file_ids = [wf.id for wf in files] or None
-            file_lines: list[str] = [
-                f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
-                for wf in files
-            ]
-            if file_lines:
-                files_block = (
-                    "\n\n[Attached files]\n"
-                    + "\n".join(file_lines)
-                    + "\nUse read_workspace_file with the file_id to access file contents."
-                )
-                request.message += files_block
-
-    # Atomically append user message to session BEFORE creating task to avoid
-    # race condition where GET_SESSION sees task as "running" but message isn't
-    # saved yet. append_and_save_message re-fetches inside a lock to prevent
-    # message loss from concurrent requests.
-    if request.message:
-        message = ChatMessage(
-            role="user" if request.is_user_message else "assistant",
-            content=request.message,
-        )
-        if request.is_user_message:
-            track_user_message(
-                user_id=user_id,
-                session_id=session_id,
-                message_length=len(request.message),
-            )
-        logger.info(f"[STREAM] Saving user message to session {session_id}")
-        await append_and_save_message(session_id, message)
-        logger.info(f"[STREAM] User message saved for session {session_id}")
-
-    # Create a task in the stream registry for reconnection support
-    turn_id = str(uuid4())
-    log_meta["turn_id"] = turn_id
-
-    session_create_start = time.perf_counter()
-    await stream_registry.create_session(
-        session_id=session_id,
-        user_id=user_id,
-        tool_call_id="chat_stream",
-        tool_name="chat",
-        turn_id=turn_id,
-    )
-    logger.info(
-        f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - session_create_start) * 1000,
-            }
-        },
-    )
-
-    # Per-turn stream is always fresh (unique turn_id), subscribe from beginning
-    subscribe_from_id = "0-0"
-
-    await enqueue_copilot_turn(
-        session_id=session_id,
-        user_id=user_id,
-        message=request.message,
-        turn_id=turn_id,
-        is_user_message=request.is_user_message,
-        context=request.context,
-        file_ids=sanitized_file_ids,
-    )
-
-    setup_time = (time.perf_counter() - stream_start_time) * 1000
-    logger.info(
-        f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
-        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
-    )
-
-    # SSE endpoint that subscribes to the task's stream
+    session = await _validate_and_get_session(session_id, user_id)
 
     async def event_generator() -> AsyncGenerator[str, None]:
-        import time as time_module
-
-        event_gen_start = time_module.perf_counter()
+        chunk_count = 0
+        first_chunk_type: str | None = None
+        async for chunk in chat_service.stream_chat_completion(
+            session_id,
+            request.message,
+            is_user_message=request.is_user_message,
+            user_id=user_id,
+            session=session,  # Pass pre-fetched session to avoid double-fetch
+            context=request.context,
+        ):
+            if chunk_count < 3:
+                logger.info(
+                    "Chat stream chunk",
+                    extra={
+                        "session_id": session_id,
+                        "chunk_type": str(chunk.type),
+                    },
+                )
+                if not first_chunk_type:
+                    first_chunk_type = str(chunk.type)
+            chunk_count += 1
+            yield chunk.to_sse()
         logger.info(
-            f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
-            f"user={user_id}",
-            extra={"json_fields": log_meta},
+            "Chat stream completed",
+            extra={
+                "session_id": session_id,
+                "chunk_count": chunk_count,
+                "first_chunk_type": first_chunk_type,
+            },
         )
-        subscriber_queue = None
-        first_chunk_yielded = False
-        chunks_yielded = 0
-        try:
-            # Subscribe from the position we captured before enqueuing
-            # This avoids replaying old messages while catching all new ones
-            subscriber_queue = await stream_registry.subscribe_to_session(
-                session_id=session_id,
-                user_id=user_id,
-                last_message_id=subscribe_from_id,
-            )
-
-            if subscriber_queue is None:
-                yield StreamFinish().to_sse()
-                yield "data: [DONE]\n\n"
-                return
-
-            # Read from the subscriber queue and yield to SSE
-            logger.info(
-                "[TIMING] Starting to read from subscriber_queue",
-                extra={"json_fields": log_meta},
-            )
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
-                    chunks_yielded += 1
-
-                    if not first_chunk_yielded:
-                        first_chunk_yielded = True
-                        elapsed = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
-                            f"type={type(chunk).__name__}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunk_type": type(chunk).__name__,
-                                    "elapsed_ms": elapsed * 1000,
-                                }
-                            },
-                        )
-
-                    yield chunk.to_sse()
-
-                    # Check for finish signal
-                    if isinstance(chunk, StreamFinish):
-                        total_time = time_module.perf_counter() - event_gen_start
-                        logger.info(
-                            f"[TIMING] StreamFinish received in {total_time:.2f}s; "
-                            f"n_chunks={chunks_yielded}",
-                            extra={
-                                "json_fields": {
-                                    **log_meta,
-                                    "chunks_yielded": chunks_yielded,
-                                    "total_time_ms": total_time * 1000,
-                                }
-                            },
-                        )
-                        break
-                except asyncio.TimeoutError:
-                    yield StreamHeartbeat().to_sse()
-
-        except GeneratorExit:
-            logger.info(
-                f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "chunks_yielded": chunks_yielded,
-                        "reason": "client_disconnect",
-                    }
-                },
-            )
-            pass  # Client disconnected - background task continues
-        except Exception as e:
-            elapsed = (time_module.perf_counter() - event_gen_start) * 1000
-            logger.error(
-                f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
-                extra={
-                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
-                },
-            )
-            # Surface error to frontend so it doesn't appear stuck
-            yield StreamError(
-                errorText="An error occurred. Please try again.",
-                code="stream_error",
-            ).to_sse()
-            yield StreamFinish().to_sse()
-        finally:
-            # Unsubscribe when client disconnects or stream ends
-            if subscriber_queue is not None:
-                try:
-                    await stream_registry.unsubscribe_from_session(
-                        session_id, subscriber_queue
-                    )
-                except Exception as unsub_err:
-                    logger.error(
-                        f"Error unsubscribing from session {session_id}: {unsub_err}",
-                        exc_info=True,
-                    )
-            # AI SDK protocol termination - always yield even if unsubscribe fails
-            total_time = time_module.perf_counter() - event_gen_start
-            logger.info(
-                f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
-                f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time * 1000,
-                        "chunks_yielded": chunks_yielded,
-                    }
-                },
-            )
-            yield "data: [DONE]\n\n"
+        # AI SDK protocol termination
+        yield "data: [DONE]\n\n"
 
     return StreamingResponse(
         event_generator(),
@@ -728,94 +270,63 @@ async def stream_chat_post(
 @router.get(
     "/sessions/{session_id}/stream",
 )
-async def resume_session_stream(
+async def stream_chat_get(
     session_id: str,
+    message: Annotated[str, Query(min_length=1, max_length=10000)],
     user_id: str | None = Depends(auth.get_user_id),
+    is_user_message: bool = Query(default=True),
 ):
     """
-    Resume an active stream for a session.
+    Stream chat responses for a session (GET - legacy endpoint).
 
-    Called by the AI SDK's ``useChat(resume: true)`` on page load.
-    Checks for an active (in-progress) task on the session and either replays
-    the full SSE stream or returns 204 No Content if nothing is running.
+    Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
+    - Text fragments as they are generated
+    - Tool call UI elements (if invoked)
+    - Tool execution results
 
     Args:
-        session_id: The chat session identifier.
+        session_id: The chat session identifier to associate with the streamed messages.
+        message: The user's new message to process.
         user_id: Optional authenticated user ID.
+        is_user_message: Whether the message is a user message.
 
     Returns:
-        StreamingResponse (SSE) when an active stream exists,
-        or 204 No Content when there is nothing to resume.
+        StreamingResponse: SSE-formatted response chunks.
 
     """
-    import asyncio
-
-    active_session, last_message_id = await stream_registry.get_active_session(
-        session_id, user_id
-    )
-
-    if not active_session:
-        return Response(status_code=204)
-
-    # Always replay from the beginning ("0-0") on resume.
-    # We can't use last_message_id because it's the latest ID in the backend
-    # stream, not the latest the frontend received — the gap causes lost
-    # messages. The frontend deduplicates replayed content.
-    subscriber_queue = await stream_registry.subscribe_to_session(
-        session_id=session_id,
-        user_id=user_id,
-        last_message_id="0-0",
-    )
-
-    if subscriber_queue is None:
-        return Response(status_code=204)
+    session = await _validate_and_get_session(session_id, user_id)
 
     async def event_generator() -> AsyncGenerator[str, None]:
         chunk_count = 0
         first_chunk_type: str | None = None
-        try:
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
-                    if chunk_count < 3:
-                        logger.info(
-                            "Resume stream chunk",
-                            extra={
-                                "session_id": session_id,
-                                "chunk_type": str(chunk.type),
-                            },
-                        )
-                    if not first_chunk_type:
-                        first_chunk_type = str(chunk.type)
-                    chunk_count += 1
-                    yield chunk.to_sse()
-
-                    if isinstance(chunk, StreamFinish):
-                        break
-                except asyncio.TimeoutError:
-                    yield StreamHeartbeat().to_sse()
-        except GeneratorExit:
-            pass
+        async for chunk in chat_service.stream_chat_completion(
+            session_id,
+            message,
+            is_user_message=is_user_message,
+            user_id=user_id,
+            session=session,  # Pass pre-fetched session to avoid double-fetch
+        ):
+            if chunk_count < 3:
+                logger.info(
+                    "Chat stream chunk",
+                    extra={
+                        "session_id": session_id,
+                        "chunk_type": str(chunk.type),
+                    },
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_session(
|
|
||||||
session_id, subscriber_queue
|
|
||||||
)
|
)
|
||||||
except Exception as unsub_err:
|
if not first_chunk_type:
|
||||||
logger.error(
|
first_chunk_type = str(chunk.type)
|
||||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
chunk_count += 1
|
||||||
exc_info=True,
|
yield chunk.to_sse()
|
||||||
)
|
logger.info(
|
||||||
logger.info(
|
"Chat stream completed",
|
||||||
"Resume stream completed",
|
extra={
|
||||||
extra={
|
"session_id": session_id,
|
||||||
"session_id": session_id,
|
"chunk_count": chunk_count,
|
||||||
"n_chunks": chunk_count,
|
"first_chunk_type": first_chunk_type,
|
||||||
"first_chunk_type": first_chunk_type,
|
},
|
||||||
},
|
)
|
||||||
)
|
# AI SDK protocol termination
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -823,8 +334,8 @@ async def resume_session_stream(
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no",
|
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||||
"x-vercel-ai-ui-message-stream": "v1",
|
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -832,6 +343,7 @@ async def resume_session_stream(
|
|||||||
@router.patch(
|
@router.patch(
|
||||||
"/sessions/{session_id}/assign-user",
|
"/sessions/{session_id}/assign-user",
|
||||||
dependencies=[Security(auth.requires_user)],
|
dependencies=[Security(auth.requires_user)],
|
||||||
|
status_code=200,
|
||||||
)
|
)
|
||||||
async def session_assign_user(
|
async def session_assign_user(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@@ -854,56 +366,6 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# ========== Suggested Prompts ==========
|
|
||||||
|
|
||||||
|
|
||||||
class SuggestedPromptsResponse(BaseModel):
|
|
||||||
"""Response model for user-specific suggested prompts."""
|
|
||||||
|
|
||||||
prompts: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/suggested-prompts",
|
|
||||||
dependencies=[Security(auth.requires_user)],
|
|
||||||
)
|
|
||||||
async def get_suggested_prompts(
|
|
||||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
|
||||||
) -> SuggestedPromptsResponse:
|
|
||||||
"""
|
|
||||||
Get LLM-generated suggested prompts for the authenticated user.
|
|
||||||
|
|
||||||
Returns personalized quick-action prompts based on the user's
|
|
||||||
business understanding. Returns an empty list if no custom prompts
|
|
||||||
are available.
|
|
||||||
"""
|
|
||||||
understanding = await get_business_understanding(user_id)
|
|
||||||
if understanding is None:
|
|
||||||
return SuggestedPromptsResponse(prompts=[])
|
|
||||||
|
|
||||||
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/ttl", status_code=200)
|
|
||||||
async def get_ttl_config() -> dict:
|
|
||||||
"""
|
|
||||||
Get the stream TTL configuration.
|
|
||||||
|
|
||||||
Returns the Time-To-Live settings for chat streams, which determines
|
|
||||||
how long clients can reconnect to an active stream.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: TTL configuration with seconds and milliseconds values.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"stream_ttl_seconds": config.stream_ttl,
|
|
||||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -940,43 +402,3 @@ async def health_check() -> dict:
|
|||||||
"service": "chat",
|
"service": "chat",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
|
||||||
|
|
||||||
ToolResponseUnion = (
|
|
||||||
AgentsFoundResponse
|
|
||||||
| NoResultsResponse
|
|
||||||
| AgentDetailsResponse
|
|
||||||
| SetupRequirementsResponse
|
|
||||||
| ExecutionStartedResponse
|
|
||||||
| NeedLoginResponse
|
|
||||||
| ErrorResponse
|
|
||||||
| InputValidationErrorResponse
|
|
||||||
| AgentOutputResponse
|
|
||||||
| UnderstandingUpdatedResponse
|
|
||||||
| AgentPreviewResponse
|
|
||||||
| AgentSavedResponse
|
|
||||||
| ClarificationNeededResponse
|
|
||||||
| SuggestedGoalResponse
|
|
||||||
| BlockListResponse
|
|
||||||
| BlockDetailsResponse
|
|
||||||
| BlockOutputResponse
|
|
||||||
| DocSearchResultsResponse
|
|
||||||
| DocPageResponse
|
|
||||||
| MCPToolsDiscoveredResponse
|
|
||||||
| MCPToolOutputResponse
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/schema/tool-responses",
|
|
||||||
response_model=ToolResponseUnion,
|
|
||||||
include_in_schema=True,
|
|
||||||
summary="[Dummy] Tool response type export for codegen",
|
|
||||||
description="This endpoint is not meant to be called. It exists solely to "
|
|
||||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
|
||||||
)
|
|
||||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
|
||||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
|
||||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
|
||||||
```diff
@@ -1,310 +0,0 @@
-"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
-
-from unittest.mock import AsyncMock, MagicMock
-
-import fastapi
-import fastapi.testclient
-import pytest
-import pytest_mock
-
-from backend.api.features.chat import routes as chat_routes
-
-app = fastapi.FastAPI()
-app.include_router(chat_routes.router)
-
-client = fastapi.testclient.TestClient(app)
-
-TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture(autouse=True)
-def setup_app_auth(mock_jwt_user):
-    """Setup auth overrides for all tests in this module"""
-    from autogpt_libs.auth.jwt_utils import get_jwt_payload
-
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def _mock_update_session_title(
-    mocker: pytest_mock.MockerFixture, *, success: bool = True
-):
-    """Mock update_session_title."""
-    return mocker.patch(
-        "backend.api.features.chat.routes.update_session_title",
-        new_callable=AsyncMock,
-        return_value=success,
-    )
-
-
-# ─── Update title: success ─────────────────────────────────────────────
-
-
-def test_update_title_success(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    mock_update = _mock_update_session_title(mocker, success=True)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": "My project"},
-    )
-
-    assert response.status_code == 200
-    assert response.json() == {"status": "ok"}
-    mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
-
-
-def test_update_title_trims_whitespace(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    mock_update = _mock_update_session_title(mocker, success=True)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": " trimmed "},
-    )
-
-    assert response.status_code == 200
-    mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
-
-
-# ─── Update title: blank / whitespace-only → 422 ──────────────────────
-
-
-def test_update_title_blank_rejected(
-    test_user_id: str,
-) -> None:
-    """Whitespace-only titles must be rejected before hitting the DB."""
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": " "},
-    )
-
-    assert response.status_code == 422
-
-
-def test_update_title_empty_rejected(
-    test_user_id: str,
-) -> None:
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": ""},
-    )
-
-    assert response.status_code == 422
-
-
-# ─── Update title: session not found or wrong user → 404 ──────────────
-
-
-def test_update_title_not_found(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    _mock_update_session_title(mocker, success=False)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": "New name"},
-    )
-
-    assert response.status_code == 404
-
-
-# ─── file_ids Pydantic validation ─────────────────────────────────────
-
-
-def test_stream_chat_rejects_too_many_file_ids():
-    """More than 20 file_ids should be rejected by Pydantic validation (422)."""
-    response = client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
-        },
-    )
-    assert response.status_code == 422
-
-
-def _mock_stream_internals(mocker: pytest_mock.MockFixture):
-    """Mock the async internals of stream_chat_post so tests can exercise
-    validation and enrichment logic without needing Redis/RabbitMQ."""
-    mocker.patch(
-        "backend.api.features.chat.routes._validate_and_get_session",
-        return_value=None,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.append_and_save_message",
-        return_value=None,
-    )
-    mock_registry = mocker.MagicMock()
-    mock_registry.create_session = mocker.AsyncMock(return_value=None)
-    mocker.patch(
-        "backend.api.features.chat.routes.stream_registry",
-        mock_registry,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.enqueue_copilot_turn",
-        return_value=None,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.track_user_message",
-        return_value=None,
-    )
-
-
-def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
-    """Exactly 20 file_ids should be accepted (not rejected by validation)."""
-    _mock_stream_internals(mocker)
-    # Patch workspace lookup as imported by the routes module
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "ws-1"})(),
-    )
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    response = client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
-        },
-    )
-    # Should get past validation — 200 streaming response expected
-    assert response.status_code == 200
-
-
-# ─── UUID format filtering ─────────────────────────────────────────────
-
-
-def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
-    """Non-UUID strings in file_ids should be silently filtered out
-    and NOT passed to the database query."""
-    _mock_stream_internals(mocker)
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "ws-1"})(),
-    )
-
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
-    client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [
-                valid_id,
-                "not-a-uuid",
-                "../../../etc/passwd",
-                "",
-            ],
-        },
-    )
-
-    # The find_many call should only receive the one valid UUID
-    mock_prisma.find_many.assert_called_once()
-    call_kwargs = mock_prisma.find_many.call_args[1]
-    assert call_kwargs["where"]["id"]["in"] == [valid_id]
-
-
-# ─── Cross-workspace file_ids ─────────────────────────────────────────
-
-
-def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
-    """The batch query should scope to the user's workspace."""
-    _mock_stream_internals(mocker)
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "my-workspace-id"})(),
-    )
-
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
-    client.post(
-        "/sessions/sess-1/stream",
-        json={"message": "hi", "file_ids": [fid]},
-    )
-
-    call_kwargs = mock_prisma.find_many.call_args[1]
-    assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
-    assert call_kwargs["where"]["isDeleted"] is False
-
-
-# ─── Suggested prompts endpoint ──────────────────────────────────────
-
-
-def _mock_get_business_understanding(
-    mocker: pytest_mock.MockerFixture,
-    *,
-    return_value=None,
-):
-    """Mock get_business_understanding."""
-    return mocker.patch(
-        "backend.api.features.chat.routes.get_business_understanding",
-        new_callable=AsyncMock,
-        return_value=return_value,
-    )
-
-
-def test_suggested_prompts_returns_prompts(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    """User with understanding and prompts gets them back."""
-    mock_understanding = MagicMock()
-    mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
-    _mock_get_business_understanding(mocker, return_value=mock_understanding)
-
-    response = client.get("/suggested-prompts")
-
-    assert response.status_code == 200
-    assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
-
-
-def test_suggested_prompts_no_understanding(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    """User with no understanding gets empty list."""
-    _mock_get_business_understanding(mocker, return_value=None)
-
-    response = client.get("/suggested-prompts")
-
-    assert response.status_code == 200
-    assert response.json() == {"prompts": []}
-
-
-def test_suggested_prompts_empty_prompts(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    """User with understanding but no prompts gets empty list."""
-    mock_understanding = MagicMock()
-    mock_understanding.suggested_prompts = []
-    _mock_get_business_understanding(mocker, return_value=mock_understanding)
-
-    response = client.get("/suggested-prompts")
-
-    assert response.status_code == 200
-    assert response.json() == {"prompts": []}
```
autogpt_platform/backend/backend/api/features/chat/service.py (new file, 1971 lines): diff suppressed because it is too large.
```diff
@@ -0,0 +1,82 @@
+import logging
+from os import getenv
+
+import pytest
+
+from . import service as chat_service
+from .model import create_chat_session, get_chat_session, upsert_chat_session
+from .response_model import (
+    StreamError,
+    StreamFinish,
+    StreamTextDelta,
+    StreamToolOutputAvailable,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_stream_chat_completion(setup_test_user, test_user_id):
+    """
+    Test the stream_chat_completion function.
+    """
+    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
+    if not api_key:
+        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
+
+    session = await create_chat_session(test_user_id)
+
+    has_errors = False
+    has_ended = False
+    assistant_message = ""
+    async for chunk in chat_service.stream_chat_completion(
+        session.session_id, "Hello, how are you?", user_id=session.user_id
+    ):
+        logger.info(chunk)
+        if isinstance(chunk, StreamError):
+            has_errors = True
+        if isinstance(chunk, StreamTextDelta):
+            assistant_message += chunk.delta
+        if isinstance(chunk, StreamFinish):
+            has_ended = True
+
+    assert has_ended, "Chat completion did not end"
+    assert not has_errors, "Error occurred while streaming chat completion"
+    assert assistant_message, "Assistant message is empty"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
+    """
+    Test the stream_chat_completion function.
+    """
+    api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
+    if not api_key:
+        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
+
+    session = await create_chat_session(test_user_id)
+    session = await upsert_chat_session(session)
+
+    has_errors = False
+    has_ended = False
+    had_tool_calls = False
+    async for chunk in chat_service.stream_chat_completion(
+        session.session_id,
+        "Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
+        user_id=session.user_id,
+    ):
+        logger.info(chunk)
+        if isinstance(chunk, StreamError):
+            has_errors = True
+
+        if isinstance(chunk, StreamFinish):
+            has_ended = True
+        if isinstance(chunk, StreamToolOutputAvailable):
+            had_tool_calls = True
+
+    assert has_ended, "Chat completion did not end"
+    assert not has_errors, "Error occurred while streaming chat completion"
+    assert had_tool_calls, "Tool calls did not occur"
+    session = await get_chat_session(session.session_id)
+    assert session, "Session not found"
+    assert session.usage, "Usage is empty"
```
```diff
@@ -0,0 +1,92 @@
+import logging
+from typing import TYPE_CHECKING, Any
+
+from openai.types.chat import ChatCompletionToolParam
+
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tracking import track_tool_called
+
+from .add_understanding import AddUnderstandingTool
+from .agent_output import AgentOutputTool
+from .base import BaseTool
+from .create_agent import CreateAgentTool
+from .edit_agent import EditAgentTool
+from .find_agent import FindAgentTool
+from .find_block import FindBlockTool
+from .find_library_agent import FindLibraryAgentTool
+from .get_doc_page import GetDocPageTool
+from .run_agent import RunAgentTool
+from .run_block import RunBlockTool
+from .search_docs import SearchDocsTool
+from .workspace_files import (
+    DeleteWorkspaceFileTool,
+    ListWorkspaceFilesTool,
+    ReadWorkspaceFileTool,
+    WriteWorkspaceFileTool,
+)
+
+if TYPE_CHECKING:
+    from backend.api.features.chat.response_model import StreamToolOutputAvailable
+
+logger = logging.getLogger(__name__)
+
+# Single source of truth for all tools
+TOOL_REGISTRY: dict[str, BaseTool] = {
+    "add_understanding": AddUnderstandingTool(),
+    "create_agent": CreateAgentTool(),
+    "edit_agent": EditAgentTool(),
+    "find_agent": FindAgentTool(),
+    "find_block": FindBlockTool(),
+    "find_library_agent": FindLibraryAgentTool(),
+    "run_agent": RunAgentTool(),
+    "run_block": RunBlockTool(),
+    "view_agent_output": AgentOutputTool(),
+    "search_docs": SearchDocsTool(),
+    "get_doc_page": GetDocPageTool(),
+    # Workspace tools for CoPilot file operations
+    "list_workspace_files": ListWorkspaceFilesTool(),
+    "read_workspace_file": ReadWorkspaceFileTool(),
+    "write_workspace_file": WriteWorkspaceFileTool(),
+    "delete_workspace_file": DeleteWorkspaceFileTool(),
+}
+
+# Export individual tool instances for backwards compatibility
+find_agent_tool = TOOL_REGISTRY["find_agent"]
+run_agent_tool = TOOL_REGISTRY["run_agent"]
+
+# Generated from registry for OpenAI API
+tools: list[ChatCompletionToolParam] = [
+    tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
+]
+
+
+def get_tool(tool_name: str) -> BaseTool | None:
+    """Get a tool instance by name."""
+    return TOOL_REGISTRY.get(tool_name)
+
+
+async def execute_tool(
+    tool_name: str,
+    parameters: dict[str, Any],
+    user_id: str | None,
+    session: ChatSession,
+    tool_call_id: str,
+) -> "StreamToolOutputAvailable":
+    """Execute a tool by name."""
+    tool = get_tool(tool_name)
+    if not tool:
+        raise ValueError(f"Tool {tool_name} not found")
+
+    # Track tool call in PostHog
+    logger.info(
+        f"Tracking tool call: tool={tool_name}, user={user_id}, "
+        f"session={session.session_id}, call_id={tool_call_id}"
+    )
+    track_tool_called(
+        user_id=user_id,
+        session_id=session.session_id,
+        tool_name=tool_name,
+        tool_call_id=tool_call_id,
+    )
+
+    return await tool.execute(user_id, session, tool_call_id, **parameters)
```
@@ -1,46 +1,22 @@
|
|||||||
import logging
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
from prisma.types import ProfileCreateInput
|
from prisma.types import ProfileCreateInput
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||||
from backend.blocks.llm import AITextGeneratorBlock
|
from backend.blocks.llm import AITextGeneratorBlock
|
||||||
from backend.copilot.model import ChatSession
|
|
||||||
from backend.data import db as db_module
|
|
||||||
from backend.data.db import prisma
|
from backend.data.db import prisma
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.data.graph import Graph, Link, Node, create_graph
|
||||||
from backend.data.model import APIKeyCredentials
|
from backend.data.model import APIKeyCredentials
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _ensure_db_connected() -> None:
|
|
||||||
"""Ensure the Prisma connection is alive on the current event loop.
|
|
||||||
|
|
||||||
On Python 3.11, the httpx transport inside Prisma can reference a stale
|
|
||||||
(closed) event loop when session-scoped async fixtures are evaluated long
|
|
||||||
after the initial ``server`` fixture connected Prisma. A cheap health-check
|
|
||||||
followed by a reconnect fixes this without affecting other fixtures.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await prisma.query_raw("SELECT 1")
|
|
||||||
except Exception:
|
|
||||||
_logger.info("Prisma connection stale – reconnecting")
|
|
||||||
try:
|
|
||||||
await db_module.disconnect()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await db_module.connect()
|
|
||||||
|
|
||||||
|
|
||||||
def make_session(user_id: str):
|
def make_session(user_id: str):
|
||||||
return ChatSession(
|
return ChatSession(
|
||||||
@@ -55,19 +31,15 @@ def make_session(user_id: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
@pytest.fixture(scope="session")
|
||||||
async def setup_test_data(server):
|
async def setup_test_data():
|
||||||
"""
|
"""
|
||||||
Set up test data for run_agent tests:
|
Set up test data for run_agent tests:
|
||||||
1. Create a test user
|
1. Create a test user
|
||||||
2. Create a test graph (agent input -> agent output)
|
2. Create a test graph (agent input -> agent output)
|
||||||
3. Create a store listing and store listing version
|
3. Create a store listing and store listing version
|
||||||
4. Approve the store listing version
|
4. Approve the store listing version
|
||||||
|
|
||||||
Depends on ``server`` to ensure Prisma is connected.
|
|
||||||
"""
|
"""
|
||||||
await _ensure_db_connected()
|
|
||||||
|
|
||||||
# 1. Create a test user
|
# 1. Create a test user
|
||||||
user_data = {
|
user_data = {
|
||||||
"sub": f"test-user-{uuid.uuid4()}",
|
"sub": f"test-user-{uuid.uuid4()}",
|
||||||
@@ -151,8 +123,8 @@ async def setup_test_data(server):
|
|||||||
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
||||||
store_submission = await store_db.create_store_submission(
|
store_submission = await store_db.create_store_submission(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
graph_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
graph_version=created_graph.version,
|
agent_version=created_graph.version,
|
||||||
slug=unique_slug,
|
slug=unique_slug,
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
description="A simple test agent",
|
description="A simple test agent",
|
||||||
@@ -161,10 +133,10 @@ async def setup_test_data(server):
         image_urls=["https://example.com/image.jpg"],
     )
 
-    assert store_submission.listing_version_id is not None
+    assert store_submission.store_listing_version_id is not None
     # 4. Approve the store listing version
     await store_db.review_store_submission(
-        store_listing_version_id=store_submission.listing_version_id,
+        store_listing_version_id=store_submission.store_listing_version_id,
         is_approved=True,
         external_comments="Approved for testing",
         internal_comments="Test approval",
@@ -178,19 +150,15 @@ async def setup_test_data(server):
     }
 
 
-@pytest_asyncio.fixture(scope="session", loop_scope="session")
-async def setup_llm_test_data(server):
+@pytest.fixture(scope="session")
+async def setup_llm_test_data():
     """
     Set up test data for LLM agent tests:
     1. Create a test user
     2. Create test OpenAI credentials for the user
     3. Create a test graph with input -> LLM block -> output
     4. Create and approve a store listing
-
-    Depends on ``server`` to ensure Prisma is connected.
     """
-    await _ensure_db_connected()
-
     key = getenv("OPENAI_API_KEY")
     if not key:
         return pytest.skip("OPENAI_API_KEY is not set")
@@ -321,8 +289,8 @@ async def setup_llm_test_data(server):
     unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
     store_submission = await store_db.create_store_submission(
         user_id=user.id,
-        graph_id=created_graph.id,
-        graph_version=created_graph.version,
+        agent_id=created_graph.id,
+        agent_version=created_graph.version,
         slug=unique_slug,
         name="LLM Test Agent",
         description="An agent with LLM capabilities",
@@ -330,9 +298,9 @@ async def setup_llm_test_data(server):
         categories=["testing", "ai"],
         image_urls=["https://example.com/image.jpg"],
     )
-    assert store_submission.listing_version_id is not None
+    assert store_submission.store_listing_version_id is not None
     await store_db.review_store_submission(
-        store_listing_version_id=store_submission.listing_version_id,
+        store_listing_version_id=store_submission.store_listing_version_id,
         is_approved=True,
         external_comments="Approved for testing",
         internal_comments="Test approval for LLM agent",
@@ -347,18 +315,14 @@ async def setup_llm_test_data(server):
     }
 
 
-@pytest_asyncio.fixture(scope="session", loop_scope="session")
-async def setup_firecrawl_test_data(server):
+@pytest.fixture(scope="session")
+async def setup_firecrawl_test_data():
    """
    Set up test data for Firecrawl agent tests (missing credentials scenario):
    1. Create a test user (WITHOUT Firecrawl credentials)
    2. Create a test graph with input -> Firecrawl block -> output
    3. Create and approve a store listing
-
-    Depends on ``server`` to ensure Prisma is connected.
    """
-    await _ensure_db_connected()
-
    # 1. Create a test user
    user_data = {
        "sub": f"test-user-{uuid.uuid4()}",
@@ -476,8 +440,8 @@ async def setup_firecrawl_test_data(server):
     unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
     store_submission = await store_db.create_store_submission(
         user_id=user.id,
-        graph_id=created_graph.id,
-        graph_version=created_graph.version,
+        agent_id=created_graph.id,
+        agent_version=created_graph.version,
         slug=unique_slug,
         name="Firecrawl Test Agent",
         description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +449,9 @@ async def setup_firecrawl_test_data(server):
         categories=["testing", "scraping"],
         image_urls=["https://example.com/image.jpg"],
     )
-    assert store_submission.listing_version_id is not None
+    assert store_submission.store_listing_version_id is not None
     await store_db.review_store_submission(
-        store_listing_version_id=store_submission.listing_version_id,
+        store_listing_version_id=store_submission.store_listing_version_id,
         is_approved=True,
         external_comments="Approved for testing",
         internal_comments="Test approval for Firecrawl agent",
@@ -3,9 +3,11 @@
 import logging
 from typing import Any
 
-from backend.copilot.model import ChatSession
-from backend.data.db_accessors import understanding_db
-from backend.data.understanding import BusinessUnderstandingInput
+from backend.api.features.chat.model import ChatSession
+from backend.data.understanding import (
+    BusinessUnderstandingInput,
+    upsert_business_understanding,
+)
 
 from .base import BaseTool
 from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
     ]
 
     # Upsert with merge
-    understanding = await understanding_db().upsert_business_understanding(
-        user_id, input_data
-    )
+    understanding = await upsert_business_understanding(user_id, input_data)
 
     # Build current understanding summary (filter out empty values)
     current_understanding = {
@@ -1,49 +1,54 @@
 """Agent generator package - Creates agents from natural language."""
 
 from .core import (
+    AgentGeneratorNotConfiguredError,
     AgentJsonValidationError,
     AgentSummary,
     DecompositionResult,
     DecompositionStep,
     LibraryAgentSummary,
     MarketplaceAgentSummary,
+    decompose_goal,
     enrich_library_agents_from_steps,
     extract_search_terms_from_steps,
     extract_uuids_from_text,
+    generate_agent,
+    generate_agent_patch,
     get_agent_as_json,
     get_all_relevant_agents_for_generation,
     get_library_agent_by_graph_id,
     get_library_agent_by_id,
-    get_library_agents_by_ids,
     get_library_agents_for_generation,
-    graph_to_json,
     json_to_graph,
     save_agent_to_library,
     search_marketplace_agents_for_generation,
 )
 from .errors import get_user_message_for_error
-from .validation import AgentFixer, AgentValidator
+from .service import health_check as check_external_service_health
+from .service import is_external_service_configured
 
 __all__ = [
-    "AgentFixer",
-    "AgentValidator",
+    "AgentGeneratorNotConfiguredError",
     "AgentJsonValidationError",
     "AgentSummary",
     "DecompositionResult",
     "DecompositionStep",
     "LibraryAgentSummary",
     "MarketplaceAgentSummary",
+    "check_external_service_health",
+    "decompose_goal",
     "enrich_library_agents_from_steps",
     "extract_search_terms_from_steps",
     "extract_uuids_from_text",
+    "generate_agent",
+    "generate_agent_patch",
     "get_agent_as_json",
     "get_all_relevant_agents_for_generation",
     "get_library_agent_by_graph_id",
     "get_library_agent_by_id",
-    "get_library_agents_by_ids",
    "get_library_agents_for_generation",
    "get_user_message_for_error",
-    "graph_to_json",
+    "is_external_service_configured",
    "json_to_graph",
    "save_agent_to_library",
    "search_marketplace_agents_for_generation",
@@ -3,17 +3,31 @@
 import logging
 import re
 import uuid
-from collections.abc import Sequence
 from typing import Any, NotRequired, TypedDict
 
-from backend.data.db_accessors import graph_db, library_db, store_db
-from backend.data.graph import Graph, Link, Node
+from backend.api.features.library import db as library_db
+from backend.api.features.store import db as store_db
+from backend.data.graph import (
+    Graph,
+    Link,
+    Node,
+    create_graph,
+    get_graph,
+    get_graph_all_versions,
+)
 from backend.util.exceptions import DatabaseError, NotFoundError
 
-from .helpers import UUID_RE_STR
+from .service import (
+    decompose_goal_external,
+    generate_agent_external,
+    generate_agent_patch_external,
+    is_external_service_configured,
+)
 
 logger = logging.getLogger(__name__)
 
+AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
+
 
 class ExecutionSummary(TypedDict):
     """Summary of a single execution for quality assessment."""
@@ -72,7 +86,38 @@ class DecompositionResult(TypedDict, total=False):
 AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
 
 
-_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
+def _to_dict_list(
+    agents: list[AgentSummary] | list[dict[str, Any]] | None,
+) -> list[dict[str, Any]] | None:
+    """Convert typed agent summaries to plain dicts for external service calls."""
+    if agents is None:
+        return None
+    return [dict(a) for a in agents]
+
+
+class AgentGeneratorNotConfiguredError(Exception):
+    """Raised when the external Agent Generator service is not configured."""
+
+    pass
+
+
+def _check_service_configured() -> None:
+    """Check if the external Agent Generator service is configured.
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the service is not configured.
+    """
+    if not is_external_service_configured():
+        raise AgentGeneratorNotConfiguredError(
+            "Agent Generator service is not configured. "
+            "Set AGENTGENERATOR_HOST environment variable to enable agent generation."
+        )
+
+
+_UUID_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
 
 
 def extract_uuids_from_text(text: str) -> list[str]:
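The `_UUID_PATTERN` introduced in the hunk above matches version-4 UUIDs only: the third group must start with a literal `4` and the fourth with one of `8/9/a/b` (the RFC 4122 variant bits). The diff does not show the body of `extract_uuids_from_text`, so the sketch below is a hypothetical minimal implementation, not the repository's actual one:

```python
import re

# Version-4 UUID pattern, as introduced in the hunk above.
_UUID_PATTERN = re.compile(
    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
    re.IGNORECASE,
)


def extract_uuids_from_text(text: str) -> list[str]:
    """Return every v4 UUID found in the text, lowercased (assumed behavior)."""
    return [match.lower() for match in _UUID_PATTERN.findall(text)]
```

Note that a non-v4 identifier such as `12345678-1234-1234-1234-123456789012` is deliberately rejected by this pattern.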
@@ -108,9 +153,8 @@ async def get_library_agent_by_id(
     Returns:
         LibraryAgentSummary if found, None otherwise
     """
-    db = library_db()
     try:
-        agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
+        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
         if agent:
             logger.debug(f"Found library agent by graph_id: {agent.name}")
             return LibraryAgentSummary(
@@ -127,7 +171,7 @@
         logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
 
     try:
-        agent = await db.get_library_agent(agent_id, user_id)
+        agent = await library_db.get_library_agent(agent_id, user_id)
         if agent:
             logger.debug(f"Found library agent by library_id: {agent.name}")
             return LibraryAgentSummary(
@@ -154,36 +198,6 @@ async def get_library_agent_by_id(
 get_library_agent_by_graph_id = get_library_agent_by_id
 
 
-async def get_library_agents_by_ids(
-    user_id: str,
-    agent_ids: list[str],
-) -> list[LibraryAgentSummary]:
-    """Fetch multiple library agents by their IDs.
-
-    Args:
-        user_id: The user ID
-        agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
-
-    Returns:
-        List of LibraryAgentSummary for found agents (silently skips not found)
-    """
-    agents: list[LibraryAgentSummary] = []
-    for agent_id in agent_ids:
-        try:
-            agent = await get_library_agent_by_id(user_id, agent_id)
-            if agent:
-                agents.append(agent)
-                logger.debug(f"Fetched library agent by ID: {agent['name']}")
-            else:
-                logger.warning(f"Library agent not found for ID: {agent_id}")
-        except Exception as e:
-            logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
-            continue
-
-    logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
-    return agents
-
-
 async def get_library_agents_for_generation(
     user_id: str,
     search_query: str | None = None,
@@ -208,17 +222,10 @@ async def get_library_agents_for_generation(
     Returns:
         List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
     """
-    search_term = search_query.strip() if search_query else None
-    if search_term and len(search_term) > 100:
-        raise ValueError(
-            f"Search query is too long ({len(search_term)} chars, max 100). "
-            f"Please use a shorter, more specific search term."
-        )
-
     try:
-        response = await library_db().list_library_agents(
+        response = await library_db.list_library_agents(
             user_id=user_id,
-            search_term=search_term,
+            search_term=search_query,
             page=1,
             page_size=max_results,
             include_executions=True,
@@ -259,58 +266,37 @@ async def get_library_agents_for_generation(
 async def search_marketplace_agents_for_generation(
     search_query: str,
     max_results: int = 10,
-) -> list[LibraryAgentSummary]:
+) -> list[MarketplaceAgentSummary]:
     """Search marketplace agents formatted for Agent Generator.
 
-    Fetches marketplace agents and their full schemas so they can be used
-    as sub-agents in generated workflows.
+    Note: This returns basic agent info. Full input/output schemas would require
+    additional graph fetches and is a potential future enhancement.
 
     Args:
         search_query: Search term to find relevant public agents
         max_results: Maximum number of agents to return (default 10)
 
     Returns:
-        List of LibraryAgentSummary with full input/output schemas
+        List of MarketplaceAgentSummary (without detailed schemas for now)
     """
-    search_term = search_query.strip()
-    if len(search_term) > 100:
-        raise ValueError(
-            f"Search query is too long ({len(search_term)} chars, max 100). "
-            f"Please use a shorter, more specific search term."
-        )
-
     try:
-        response = await store_db().get_store_agents(
-            search_query=search_term,
+        response = await store_db.get_store_agents(
+            search_query=search_query,
             page=1,
             page_size=max_results,
         )
 
-        agents_with_graphs = [
-            agent for agent in response.agents if agent.agent_graph_id
-        ]
-
-        if not agents_with_graphs:
-            return []
-
-        graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
-        graphs = await graph_db().get_store_listed_graphs(graph_ids)
-
-        results: list[LibraryAgentSummary] = []
-        for agent in agents_with_graphs:
-            graph_id = agent.agent_graph_id
-            if graph_id and graph_id in graphs:
-                graph = graphs[graph_id]
-                results.append(
-                    LibraryAgentSummary(
-                        graph_id=graph.id,
-                        graph_version=graph.version,
-                        name=agent.agent_name,
-                        description=agent.description,
-                        input_schema=graph.input_schema,
-                        output_schema=graph.output_schema,
-                    )
-                )
+        results: list[MarketplaceAgentSummary] = []
+        for agent in response.agents:
+            results.append(
+                MarketplaceAgentSummary(
+                    name=agent.agent_name,
+                    description=agent.description,
+                    sub_heading=agent.sub_heading,
+                    creator=agent.creator,
+                    is_marketplace_agent=True,
+                )
+            )
         return results
     except Exception as e:
         logger.warning(f"Failed to search marketplace agents: {e}")
@@ -341,7 +327,8 @@ async def get_all_relevant_agents_for_generation(
         max_marketplace_results: Max marketplace agents to return (default 10)
 
     Returns:
-        List of AgentSummary with full schemas (both library and marketplace agents)
+        List of AgentSummary, library agents first (with full schemas),
+        then marketplace agents (basic info only)
     """
     agents: list[AgentSummary] = []
     seen_graph_ids: set[str] = set()
@@ -378,11 +365,16 @@ async def get_all_relevant_agents_for_generation(
             search_query=search_query,
             max_results=max_marketplace_results,
         )
+        library_names: set[str] = set()
+        for a in agents:
+            name = a.get("name")
+            if name and isinstance(name, str):
+                library_names.add(name.lower())
         for agent in marketplace_agents:
-            graph_id = agent.get("graph_id")
-            if graph_id and graph_id not in seen_graph_ids:
-                agents.append(agent)
-                seen_graph_ids.add(graph_id)
+            agent_name = agent.get("name")
+            if agent_name and isinstance(agent_name, str):
+                if agent_name.lower() not in library_names:
+                    agents.append(agent)
 
     return agents
 
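The hunk above changes the marketplace/library merge from graph-ID dedup to case-insensitive name dedup, because marketplace summaries no longer carry a `graph_id`. A standalone sketch of that merge over plain dicts (names and structure are illustrative, not the repository's types):

```python
def merge_agents(
    library_agents: list[dict],
    marketplace_agents: list[dict],
) -> list[dict]:
    """Library agents first; marketplace agents whose name collides
    case-insensitively with a library agent's name are dropped."""
    merged = list(library_agents)
    library_names = {
        a["name"].lower()
        for a in library_agents
        if isinstance(a.get("name"), str)
    }
    for agent in marketplace_agents:
        name = agent.get("name")
        if isinstance(name, str) and name and name.lower() not in library_names:
            merged.append(agent)
    return merged
```

One consequence of this design, visible in the diff: a marketplace agent without a string `name` is silently skipped, where the old graph-ID dedup would have skipped agents without a `graph_id` instead.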
@@ -432,7 +424,7 @@ def extract_search_terms_from_steps(
 async def enrich_library_agents_from_steps(
     user_id: str,
     decomposition_result: DecompositionResult | dict[str, Any],
-    existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
+    existing_agents: list[AgentSummary] | list[dict[str, Any]],
     exclude_graph_id: str | None = None,
     include_marketplace: bool = True,
     max_additional_results: int = 10,
@@ -456,7 +448,7 @@ async def enrich_library_agents_from_steps(
     search_terms = extract_search_terms_from_steps(decomposition_result)
 
     if not search_terms:
-        return list(existing_agents)
+        return existing_agents
 
     existing_ids: set[str] = set()
     existing_names: set[str] = set()
@@ -516,6 +508,68 @@ async def enrich_library_agents_from_steps(
     return all_agents
 
 
+async def decompose_goal(
+    description: str,
+    context: str = "",
+    library_agents: list[AgentSummary] | None = None,
+) -> DecompositionResult | None:
+    """Break down a goal into steps or return clarifying questions.
+
+    Args:
+        description: Natural language goal description
+        context: Additional context (e.g., answers to previous questions)
+        library_agents: User's library agents available for sub-agent composition
+
+    Returns:
+        DecompositionResult with either:
+        - {"type": "clarifying_questions", "questions": [...]}
+        - {"type": "instructions", "steps": [...]}
+        Or None on error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
+    """
+    _check_service_configured()
+    logger.info("Calling external Agent Generator service for decompose_goal")
+    result = await decompose_goal_external(
+        description, context, _to_dict_list(library_agents)
+    )
+    return result  # type: ignore[return-value]
+
+
+async def generate_agent(
+    instructions: DecompositionResult | dict[str, Any],
+    library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
+) -> dict[str, Any] | None:
+    """Generate agent JSON from instructions.
+
+    Args:
+        instructions: Structured instructions from decompose_goal
+        library_agents: User's library agents available for sub-agent composition
+
+    Returns:
+        Agent JSON dict, error dict {"type": "error", ...}, or None on error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
+    """
+    _check_service_configured()
+    logger.info("Calling external Agent Generator service for generate_agent")
+    result = await generate_agent_external(
+        dict(instructions), _to_dict_list(library_agents)
+    )
+    if result:
+        if isinstance(result, dict) and result.get("type") == "error":
+            return result
+        if "id" not in result:
+            result["id"] = str(uuid.uuid4())
+        if "version" not in result:
+            result["version"] = 1
+        if "is_active" not in result:
+            result["is_active"] = True
+    return result
+
+
 class AgentJsonValidationError(Exception):
     """Raised when agent JSON is invalid or missing required fields."""
 
@@ -594,11 +648,47 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
     )
 
 
+def _reassign_node_ids(graph: Graph) -> None:
+    """Reassign all node and link IDs to new UUIDs.
+
+    This is needed when creating a new version to avoid unique constraint violations.
+    """
+    id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
+
+    for node in graph.nodes:
+        node.id = id_map[node.id]
+
+    for link in graph.links:
+        link.id = str(uuid.uuid4())
+        if link.source_id in id_map:
+            link.source_id = id_map[link.source_id]
+        if link.sink_id in id_map:
+            link.sink_id = id_map[link.sink_id]
+
+
+def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
+    """Populate user_id in AgentExecutorBlock nodes.
+
+    The external agent generator creates AgentExecutorBlock nodes with empty user_id.
+    This function fills in the actual user_id so sub-agents run with correct permissions.
+
+    Args:
+        agent_json: Agent JSON dict (modified in place)
+        user_id: User ID to set
+    """
+    for node in agent_json.get("nodes", []):
+        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
+            input_default = node.get("input_default") or {}
+            if not input_default.get("user_id"):
+                input_default["user_id"] = user_id
+                node["input_default"] = input_default
+                logger.debug(
+                    f"Set user_id for AgentExecutorBlock node {node.get('id')}"
+                )
+
+
 async def save_agent_to_library(
-    agent_json: dict[str, Any],
-    user_id: str,
-    is_update: bool = False,
-    folder_id: str | None = None,
+    agent_json: dict[str, Any], user_id: str, is_update: bool = False
 ) -> tuple[Graph, Any]:
     """Save agent to database and user's library.
 
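The `_reassign_node_ids` helper added above preserves graph topology while giving every node and link a fresh UUID: the invariant is that each link's endpoints still reference the same (renamed) nodes. A minimal sketch of the same logic, using hypothetical dataclass stand-ins for `backend.data.graph.Node` and `Link`:

```python
import uuid
from dataclasses import dataclass


@dataclass
class Node:  # hypothetical stand-in for backend.data.graph.Node
    id: str


@dataclass
class Link:  # hypothetical stand-in for backend.data.graph.Link
    id: str
    source_id: str
    sink_id: str


def reassign_ids(nodes: list[Node], links: list[Link]) -> None:
    """Give every node and link a fresh UUID, rewriting link endpoints
    so the graph's topology is unchanged."""
    id_map = {node.id: str(uuid.uuid4()) for node in nodes}
    for node in nodes:
        node.id = id_map[node.id]
    for link in links:
        link.id = str(uuid.uuid4())
        if link.source_id in id_map:
            link.source_id = id_map[link.source_id]
        if link.sink_id in id_map:
            link.sink_id = id_map[link.sink_id]
```

The `in id_map` guards mirror the diff: a link endpoint that references a node outside the graph is left untouched rather than raising.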
@@ -606,27 +696,67 @@ async def save_agent_to_library(
         agent_json: Agent JSON dict
         user_id: User ID
         is_update: Whether this is an update to an existing agent
-        folder_id: Optional folder ID to place the agent in
 
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
+    # Populate user_id in AgentExecutorBlock nodes before conversion
+    _populate_agent_executor_user_ids(agent_json, user_id)
+
     graph = json_to_graph(agent_json)
-    db = library_db()
     if is_update:
-        return await db.update_graph_in_library(graph, user_id)
-    return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
+        if graph.id:
+            existing_versions = await get_graph_all_versions(graph.id, user_id)
+            if existing_versions:
+                latest_version = max(v.version for v in existing_versions)
+                graph.version = latest_version + 1
+            _reassign_node_ids(graph)
+            logger.info(f"Updating agent {graph.id} to version {graph.version}")
+    else:
+        graph.id = str(uuid.uuid4())
+        graph.version = 1
+        _reassign_node_ids(graph)
+        logger.info(f"Creating new agent with ID {graph.id}")
+
+    created_graph = await create_graph(graph, user_id)
+
+    library_agents = await library_db.create_library_agent(
+        graph=created_graph,
+        user_id=user_id,
+        sensitive_action_safe_mode=True,
+        create_library_agents_for_sub_graphs=False,
+    )
+
+    return created_graph, library_agents[0]
 
 
-def graph_to_json(graph: Graph) -> dict[str, Any]:
-    """Convert a Graph object to JSON format for the agent generator.
+async def get_agent_as_json(
+    agent_id: str, user_id: str | None
+) -> dict[str, Any] | None:
+    """Fetch an agent and convert to JSON format for editing.
 
     Args:
-        graph: Graph object to convert
+        agent_id: Graph ID or library agent ID
+        user_id: User ID
 
     Returns:
-        Agent as JSON dict
+        Agent as JSON dict or None if not found
     """
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+
+    if not graph:
+        return None
+
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -663,32 +793,32 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
     }
 
 
-async def get_agent_as_json(
-    agent_id: str, user_id: str | None
+async def generate_agent_patch(
+    update_request: str,
+    current_agent: dict[str, Any],
+    library_agents: list[AgentSummary] | None = None,
 ) -> dict[str, Any] | None:
-    """Fetch an agent and convert to JSON format for editing.
+    """Update an existing agent using natural language.
+
+    The external Agent Generator service handles:
+    - Generating the patch
+    - Applying the patch
+    - Fixing and validating the result
 
     Args:
-        agent_id: Graph ID or library agent ID
-        user_id: User ID
+        update_request: Natural language description of changes
+        current_agent: Current agent JSON
+        library_agents: User's library agents available for sub-agent composition
 
     Returns:
-        Agent as JSON dict or None if not found
+        Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
+        error dict {"type": "error", ...}, or None on unexpected error
+
+    Raises:
+        AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
-    db = graph_db()
-
-    graph = await db.get_graph(agent_id, version=None, user_id=user_id)
-
-    if not graph and user_id:
-        try:
-            library_agent = await library_db().get_library_agent(agent_id, user_id)
-            graph = await db.get_graph(
-                library_agent.graph_id, version=None, user_id=user_id
-            )
-        except NotFoundError:
-            pass
-
-    if not graph:
-        return None
-
-    return graph_to_json(graph)
+    _check_service_configured()
+    logger.info("Calling external Agent Generator service for generate_agent_patch")
+    return await generate_agent_patch_external(
+        update_request, current_agent, _to_dict_list(library_agents)
+    )
@@ -0,0 +1,386 @@
"""External Agent Generator service client.

This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""

import logging
from typing import Any

import httpx

from backend.util.settings import Settings

logger = logging.getLogger(__name__)


def _create_error_response(
    error_message: str,
    error_type: str = "unknown",
    details: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Create a standardized error response dict.

    Args:
        error_message: Human-readable error message
        error_type: Machine-readable error type
        details: Optional additional error details

    Returns:
        Error dict with type="error" and error details
    """
    response: dict[str, Any] = {
        "type": "error",
        "error": error_message,
        "error_type": error_type,
    }
    if details:
        response["details"] = details
    return response


def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
    """Classify an HTTP error into error_type and message.

    Args:
        e: The HTTP status error

    Returns:
        Tuple of (error_type, error_message)
    """
    status = e.response.status_code
    if status == 429:
        return "rate_limit", f"Agent Generator rate limited: {e}"
    elif status == 503:
        return "service_unavailable", f"Agent Generator unavailable: {e}"
    elif status == 504 or status == 408:
        return "timeout", f"Agent Generator timed out: {e}"
    else:
        return "http_error", f"HTTP error calling Agent Generator: {e}"


def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
    """Classify a request error into error_type and message.

    Args:
        e: The request error

    Returns:
        Tuple of (error_type, error_message)
    """
    error_str = str(e).lower()
    if "timeout" in error_str or "timed out" in error_str:
        return "timeout", f"Agent Generator request timed out: {e}"
    elif "connect" in error_str:
        return "connection_error", f"Could not connect to Agent Generator: {e}"
    else:
        return "request_error", f"Request error calling Agent Generator: {e}"


_client: httpx.AsyncClient | None = None
_settings: Settings | None = None


def _get_settings() -> Settings:
    """Get or create settings singleton."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def is_external_service_configured() -> bool:
    """Check if external Agent Generator service is configured."""
    settings = _get_settings()
    return bool(settings.config.agentgenerator_host)


def _get_base_url() -> str:
    """Get the base URL for the external service."""
    settings = _get_settings()
    host = settings.config.agentgenerator_host
    port = settings.config.agentgenerator_port
    return f"http://{host}:{port}"


def _get_client() -> httpx.AsyncClient:
    """Get or create the HTTP client for the external service."""
    global _client
    if _client is None:
        settings = _get_settings()
        _client = httpx.AsyncClient(
            base_url=_get_base_url(),
            timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
        )
    return _client


async def decompose_goal_external(
    description: str,
    context: str = "",
    library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
    """Call the external service to decompose a goal.

    Args:
        description: Natural language goal description
        context: Additional context (e.g., answers to previous questions)
        library_agents: User's library agents available for sub-agent composition

    Returns:
        Dict with either:
        - {"type": "clarifying_questions", "questions": [...]}
        - {"type": "instructions", "steps": [...]}
        - {"type": "unachievable_goal", ...}
        - {"type": "vague_goal", ...}
        - {"type": "error", "error": "...", "error_type": "..."} on error
        Or None on unexpected error
    """
    client = _get_client()

    # Build the request payload
    payload: dict[str, Any] = {"description": description}
    if context:
        # The external service uses user_instruction for additional context
        payload["user_instruction"] = context
    if library_agents:
        payload["library_agents"] = library_agents

    try:
        response = await client.post("/api/decompose-description", json=payload)
        response.raise_for_status()
        data = response.json()

        if not data.get("success"):
            error_msg = data.get("error", "Unknown error from Agent Generator")
            error_type = data.get("error_type", "unknown")
            logger.error(
                f"Agent Generator decomposition failed: {error_msg} "
                f"(type: {error_type})"
            )
            return _create_error_response(error_msg, error_type)

        # Map the response to the expected format
        response_type = data.get("type")
        if response_type == "instructions":
            return {"type": "instructions", "steps": data.get("steps", [])}
        elif response_type == "clarifying_questions":
            return {
                "type": "clarifying_questions",
                "questions": data.get("questions", []),
            }
        elif response_type == "unachievable_goal":
            return {
                "type": "unachievable_goal",
                "reason": data.get("reason"),
                "suggested_goal": data.get("suggested_goal"),
            }
        elif response_type == "vague_goal":
            return {
                "type": "vague_goal",
                "suggested_goal": data.get("suggested_goal"),
            }
        elif response_type == "error":
            # Pass through error from the service
            return _create_error_response(
                data.get("error", "Unknown error"),
                data.get("error_type", "unknown"),
            )
        else:
            logger.error(
                f"Unknown response type from external service: {response_type}"
            )
            return _create_error_response(
                f"Unknown response type from Agent Generator: {response_type}",
                "invalid_response",
            )

    except httpx.HTTPStatusError as e:
        error_type, error_msg = _classify_http_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except httpx.RequestError as e:
        error_type, error_msg = _classify_request_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except Exception as e:
        error_msg = f"Unexpected error calling Agent Generator: {e}"
        logger.error(error_msg)
        return _create_error_response(error_msg, "unexpected_error")


async def generate_agent_external(
    instructions: dict[str, Any],
    library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
    """Call the external service to generate an agent from instructions.

    Args:
        instructions: Structured instructions from decompose_goal
        library_agents: User's library agents available for sub-agent composition

    Returns:
        Agent JSON dict on success, or error dict {"type": "error", ...} on error
    """
    client = _get_client()

    payload: dict[str, Any] = {"instructions": instructions}
    if library_agents:
        payload["library_agents"] = library_agents

    try:
        response = await client.post("/api/generate-agent", json=payload)
        response.raise_for_status()
        data = response.json()

        if not data.get("success"):
            error_msg = data.get("error", "Unknown error from Agent Generator")
            error_type = data.get("error_type", "unknown")
            logger.error(
                f"Agent Generator generation failed: {error_msg} (type: {error_type})"
            )
            return _create_error_response(error_msg, error_type)

        return data.get("agent_json")

    except httpx.HTTPStatusError as e:
        error_type, error_msg = _classify_http_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except httpx.RequestError as e:
        error_type, error_msg = _classify_request_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except Exception as e:
        error_msg = f"Unexpected error calling Agent Generator: {e}"
        logger.error(error_msg)
        return _create_error_response(error_msg, "unexpected_error")


async def generate_agent_patch_external(
    update_request: str,
    current_agent: dict[str, Any],
    library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
    """Call the external service to generate a patch for an existing agent.

    Args:
        update_request: Natural language description of changes
        current_agent: Current agent JSON
        library_agents: User's library agents available for sub-agent composition

    Returns:
        Updated agent JSON, clarifying questions dict, or error dict on error
    """
    client = _get_client()

    payload: dict[str, Any] = {
        "update_request": update_request,
        "current_agent_json": current_agent,
    }
    if library_agents:
        payload["library_agents"] = library_agents

    try:
        response = await client.post("/api/update-agent", json=payload)
        response.raise_for_status()
        data = response.json()

        if not data.get("success"):
            error_msg = data.get("error", "Unknown error from Agent Generator")
            error_type = data.get("error_type", "unknown")
            logger.error(
                f"Agent Generator patch generation failed: {error_msg} "
                f"(type: {error_type})"
            )
            return _create_error_response(error_msg, error_type)

        # Check if it's clarifying questions
        if data.get("type") == "clarifying_questions":
            return {
                "type": "clarifying_questions",
                "questions": data.get("questions", []),
            }

        # Check if it's an error passed through
        if data.get("type") == "error":
            return _create_error_response(
                data.get("error", "Unknown error"),
                data.get("error_type", "unknown"),
            )

        # Otherwise return the updated agent JSON
        return data.get("agent_json")

    except httpx.HTTPStatusError as e:
        error_type, error_msg = _classify_http_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except httpx.RequestError as e:
        error_type, error_msg = _classify_request_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except Exception as e:
        error_msg = f"Unexpected error calling Agent Generator: {e}"
        logger.error(error_msg)
        return _create_error_response(error_msg, "unexpected_error")


async def get_blocks_external() -> list[dict[str, Any]] | None:
    """Get available blocks from the external service.

    Returns:
        List of block info dicts or None on error
    """
    client = _get_client()

    try:
        response = await client.get("/api/blocks")
        response.raise_for_status()
        data = response.json()

        if not data.get("success"):
            logger.error("External service returned error getting blocks")
            return None

        return data.get("blocks", [])

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error getting blocks from external service: {e}")
        return None
    except httpx.RequestError as e:
        logger.error(f"Request error getting blocks from external service: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error getting blocks from external service: {e}")
        return None


async def health_check() -> bool:
    """Check if the external service is healthy.

    Returns:
        True if healthy, False otherwise
    """
    if not is_external_service_configured():
        return False

    client = _get_client()

    try:
        response = await client.get("/health")
        response.raise_for_status()
        data = response.json()
        return data.get("status") == "healthy" and data.get("blocks_loaded", False)
    except Exception as e:
        logger.warning(f"External agent generator health check failed: {e}")
        return False


async def close_client() -> None:
    """Close the HTTP client."""
    global _client
    if _client is not None:
        await _client.aclose()
        _client = None
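The status-code branching in `_classify_http_error` above can be shown standalone. This sketch mirrors just the branch structure on bare integers, with no httpx dependency; the function name and return strings below are illustrative, not the module's API:

```python
def classify_status(status: int) -> str:
    # Mirror of _classify_http_error's branches, operating on a bare status code.
    if status == 429:
        return "rate_limit"
    elif status == 503:
        return "service_unavailable"
    elif status in (504, 408):
        return "timeout"
    else:
        return "http_error"


for code in (429, 503, 504, 408, 500):
    print(code, classify_status(code))
```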
@@ -5,15 +5,15 @@ import re
 from datetime import datetime, timedelta, timezone
 from typing import Any

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, field_validator

+from backend.api.features.chat.model import ChatSession
+from backend.api.features.library import db as library_db
 from backend.api.features.library.model import LibraryAgent
-from backend.copilot.model import ChatSession
-from backend.data.db_accessors import execution_db, library_db
+from backend.data import execution as execution_db
 from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta

 from .base import BaseTool
-from .execution_utils import TERMINAL_STATUSES, wait_for_execution
 from .models import (
     AgentOutputResponse,
     ErrorResponse,
@@ -34,7 +34,6 @@ class AgentOutputInput(BaseModel):
     store_slug: str = ""
     execution_id: str = ""
     run_time: str = "latest"
-    wait_if_running: int = Field(default=0, ge=0, le=300)

     @field_validator(
         "agent_name",
@@ -118,11 +117,6 @@ class AgentOutputTool(BaseTool):
     Select which run to retrieve using:
     - execution_id: Specific execution ID
    - run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
-
-    Wait for completion (optional):
-    - wait_if_running: Max seconds to wait if execution is still running (0-300).
-      If the execution is running/queued, waits up to this many seconds for completion.
-      Returns current status on timeout. If already finished, returns immediately.
     """

     @property
@@ -152,13 +146,6 @@ class AgentOutputTool(BaseTool):
                     "Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
                 ),
             },
-            "wait_if_running": {
-                "type": "integer",
-                "description": (
-                    "Max seconds to wait if execution is still running (0-300). "
-                    "If running, waits for completion. Returns current state on timeout."
-                ),
-            },
         },
         "required": [],
     }
@@ -178,12 +165,10 @@ class AgentOutputTool(BaseTool):
         Resolve agent from provided identifiers.
         Returns (library_agent, error_message).
         """
-        lib_db = library_db()
-
         # Priority 1: Exact library agent ID
         if library_agent_id:
             try:
-                agent = await lib_db.get_library_agent(library_agent_id, user_id)
+                agent = await library_db.get_library_agent(library_agent_id, user_id)
                 return agent, None
             except Exception as e:
                 logger.warning(f"Failed to get library agent by ID: {e}")
@@ -197,7 +182,7 @@ class AgentOutputTool(BaseTool):
                 return None, f"Agent '{store_slug}' not found in marketplace"

             # Find in user's library by graph_id
-            agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
+            agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
             if not agent:
                 return (
                     None,
@@ -209,7 +194,7 @@ class AgentOutputTool(BaseTool):
         # Priority 3: Fuzzy name search in library
         if agent_name:
             try:
-                response = await lib_db.list_library_agents(
+                response = await library_db.list_library_agents(
                     user_id=user_id,
                     search_term=agent_name,
                     page_size=5,
@@ -238,20 +223,14 @@ class AgentOutputTool(BaseTool):
         execution_id: str | None,
         time_start: datetime | None,
         time_end: datetime | None,
-        include_running: bool = False,
     ) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
         """
         Fetch execution(s) based on filters.
         Returns (single_execution, available_executions_meta, error_message).
-
-        Args:
-            include_running: If True, also look for running/queued executions (for waiting)
         """
-        exec_db = execution_db()
-
         # If specific execution_id provided, fetch it directly
         if execution_id:
-            execution = await exec_db.get_graph_execution(
+            execution = await execution_db.get_graph_execution(
                 user_id=user_id,
                 execution_id=execution_id,
                 include_node_executions=False,
@@ -260,25 +239,11 @@ class AgentOutputTool(BaseTool):
                 return None, [], f"Execution '{execution_id}' not found"
             return execution, [], None

-        # Determine which statuses to query
-        statuses = [ExecutionStatus.COMPLETED]
-        if include_running:
-            statuses.extend(
-                [
-                    ExecutionStatus.RUNNING,
-                    ExecutionStatus.QUEUED,
-                    ExecutionStatus.INCOMPLETE,
-                    ExecutionStatus.REVIEW,
-                    ExecutionStatus.FAILED,
-                    ExecutionStatus.TERMINATED,
-                ]
-            )
-
-        # Get executions with time filters
-        executions = await exec_db.get_graph_executions(
+        # Get completed executions with time filters
+        executions = await execution_db.get_graph_executions(
             graph_id=graph_id,
             user_id=user_id,
-            statuses=statuses,
+            statuses=[ExecutionStatus.COMPLETED],
             created_time_gte=time_start,
             created_time_lte=time_end,
             limit=10,
@@ -289,7 +254,7 @@ class AgentOutputTool(BaseTool):

         # If only one execution, fetch full details
         if len(executions) == 1:
-            full_execution = await exec_db.get_graph_execution(
+            full_execution = await execution_db.get_graph_execution(
                 user_id=user_id,
                 execution_id=executions[0].id,
                 include_node_executions=False,
@@ -297,7 +262,7 @@ class AgentOutputTool(BaseTool):
             return full_execution, [], None

         # Multiple executions - return latest with full details, plus list of available
-        full_execution = await exec_db.get_graph_execution(
+        full_execution = await execution_db.get_graph_execution(
             user_id=user_id,
             execution_id=executions[0].id,
             include_node_executions=False,
@@ -345,33 +310,10 @@ class AgentOutputTool(BaseTool):
             for e in available_executions[:5]
         ]

-        # Build appropriate message based on execution status
-        if execution.status == ExecutionStatus.COMPLETED:
-            message = f"Found execution outputs for agent '{agent.name}'"
-        elif execution.status == ExecutionStatus.FAILED:
-            message = f"Execution for agent '{agent.name}' failed"
-        elif execution.status == ExecutionStatus.TERMINATED:
-            message = f"Execution for agent '{agent.name}' was terminated"
-        elif execution.status == ExecutionStatus.REVIEW:
-            message = (
-                f"Execution for agent '{agent.name}' is awaiting human review. "
-                "The user needs to approve it before it can continue."
-            )
-        elif execution.status in (
-            ExecutionStatus.RUNNING,
-            ExecutionStatus.QUEUED,
-            ExecutionStatus.INCOMPLETE,
-        ):
-            message = (
-                f"Execution for agent '{agent.name}' is still {execution.status.value}. "
-                "Results may be incomplete. Use wait_if_running to wait for completion."
-            )
-        else:
-            message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
-
+        message = f"Found execution outputs for agent '{agent.name}'"
         if len(available_executions) > 1:
             message += (
-                f" Showing latest of {len(available_executions)} matching executions."
+                f". Showing latest of {len(available_executions)} matching executions."
             )

         return AgentOutputResponse(
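The simplified message construction in the hunk above reduces to plain string assembly: one base message plus an optional count suffix (note the leading period the new version adds). A minimal sketch; the standalone function and the sample agent name are illustrative only:

```python
def build_message(agent_name: str, num_matching: int) -> str:
    # Base message, plus a suffix only when several executions matched.
    message = f"Found execution outputs for agent '{agent_name}'"
    if num_matching > 1:
        message += f". Showing latest of {num_matching} matching executions."
    return message


print(build_message("Daily Digest", 1))
print(build_message("Daily Digest", 3))
```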
@@ -438,7 +380,7 @@ class AgentOutputTool(BaseTool):
             and not input_data.store_slug
         ):
             # Fetch execution directly to get graph_id
-            execution = await execution_db().get_graph_execution(
+            execution = await execution_db.get_graph_execution(
                 user_id=user_id,
                 execution_id=input_data.execution_id,
                 include_node_executions=False,
@@ -450,7 +392,7 @@ class AgentOutputTool(BaseTool):
             )

             # Find library agent by graph_id
-            agent = await library_db().get_library_agent_by_graph_id(
+            agent = await library_db.get_library_agent_by_graph_id(
                 user_id, execution.graph_id
             )
             if not agent:
@@ -486,17 +428,13 @@ class AgentOutputTool(BaseTool):
             # Parse time expression
             time_start, time_end = parse_time_expression(input_data.run_time)

-            # Check if we should wait for running executions
-            wait_timeout = input_data.wait_if_running
-
-            # Fetch execution(s) - include running if we're going to wait
+            # Fetch execution(s)
             execution, available_executions, exec_error = await self._get_execution(
                 user_id=user_id,
                 graph_id=agent.graph_id,
                 execution_id=input_data.execution_id or None,
                 time_start=time_start,
                 time_end=time_end,
-                include_running=wait_timeout > 0,
             )

             if exec_error:
@@ -505,17 +443,4 @@ class AgentOutputTool(BaseTool):
                 session_id=session_id,
             )

-        # If we have an execution that's still running and we should wait
-        if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
-            logger.info(
-                f"Execution {execution.id} is {execution.status}, "
-                f"waiting up to {wait_timeout}s for completion"
-            )
-            execution = await wait_for_execution(
-                user_id=user_id,
-                graph_id=agent.graph_id,
-                execution_id=execution.id,
-                timeout_seconds=wait_timeout,
-            )
-
        return self._build_response(agent, execution, available_executions, session_id)
@@ -1,15 +1,11 @@
 """Shared agent search functionality for find_agent and find_library_agent tools."""

-from __future__ import annotations
-
 import logging
 import re
-from typing import TYPE_CHECKING, Literal
+from typing import Literal

-if TYPE_CHECKING:
-    from backend.api.features.library.model import LibraryAgent
-
-from backend.data.db_accessors import library_db, store_db
+from backend.api.features.library import db as library_db
+from backend.api.features.store import db as store_db
 from backend.util.exceptions import DatabaseError, NotFoundError

 from .models import (
```diff
@@ -29,24 +25,92 @@ _UUID_PATTERN = re.compile(
     re.IGNORECASE,
 )
 
-# Keywords that should be treated as "list all" rather than a literal search
-_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
+def _is_uuid(text: str) -> bool:
+    """Check if text is a valid UUID v4."""
+    return bool(_UUID_PATTERN.match(text.strip()))
+
+
+async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
+    """Fetch a library agent by ID (library agent ID or graph_id).
+
+    Tries multiple lookup strategies:
+    1. First by graph_id (AgentGraph primary key)
+    2. Then by library agent ID (LibraryAgent primary key)
+
+    Args:
+        user_id: The user ID
+        agent_id: The ID to look up (can be graph_id or library agent ID)
+
+    Returns:
+        AgentInfo if found, None otherwise
+    """
+    try:
+        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        if agent:
+            logger.debug(f"Found library agent by graph_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by graph_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    try:
+        agent = await library_db.get_library_agent(agent_id, user_id)
+        if agent:
+            logger.debug(f"Found library agent by library_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except NotFoundError:
+        logger.debug(f"Library agent not found by library_id: {agent_id}")
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by library_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    return None
 
 
 async def search_agents(
     query: str,
     source: SearchSource,
-    session_id: str | None = None,
+    session_id: str | None,
     user_id: str | None = None,
 ) -> ToolResponseBase:
     """
     Search for agents in marketplace or user library.
 
-    For library searches, keywords like "all", "*", "everything", or an empty
-    query will list all agents without filtering.
 
     Args:
-        query: Search query string. Special keywords list all library agents.
+        query: Search query string
         source: "marketplace" or "library"
         session_id: Chat session ID
         user_id: User ID (required for library search)
```
```diff
@@ -54,11 +118,7 @@ async def search_agents(
     Returns:
         AgentsFoundResponse, NoResultsResponse, or ErrorResponse
     """
-    # Normalize list-all keywords to empty string for library searches
-    if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
-        query = ""
-
-    if source == "marketplace" and not query:
+    if not query:
         return ErrorResponse(
             message="Please provide a search query", session_id=session_id
         )
```
```diff
@@ -73,7 +133,7 @@ async def search_agents(
     try:
         if source == "marketplace":
             logger.info(f"Searching marketplace for: {query}")
-            results = await store_db().get_store_agents(search_query=query, page_size=5)
+            results = await store_db.get_store_agents(search_query=query, page_size=5)
             for agent in results.agents:
                 agents.append(
                     AgentInfo(
```
```diff
@@ -98,18 +158,28 @@ async def search_agents(
                 logger.info(f"Found agent by direct ID lookup: {agent.name}")
 
             if not agents:
-                search_term = query or None
-                logger.info(
-                    f"{'Listing all agents in' if not query else 'Searching'} "
-                    f"user library{'' if not query else f' for: {query}'}"
-                )
-                results = await library_db().list_library_agents(
+                logger.info(f"Searching user library for: {query}")
+                results = await library_db.list_library_agents(
                     user_id=user_id,  # type: ignore[arg-type]
-                    search_term=search_term,
-                    page_size=50 if not query else 10,
+                    search_term=query,
+                    page_size=10,
                 )
                 for agent in results.agents:
-                    agents.append(_library_agent_to_info(agent))
+                    agents.append(
+                        AgentInfo(
+                            id=agent.id,
+                            name=agent.name,
+                            description=agent.description or "",
+                            source="library",
+                            in_library=True,
+                            creator=agent.creator_name,
+                            status=agent.status.value,
+                            can_access_graph=agent.can_access_graph,
+                            has_external_trigger=agent.has_external_trigger,
+                            new_output=agent.new_output,
+                            graph_id=agent.graph_id,
+                        )
+                    )
         logger.info(f"Found {len(agents)} agents in {source}")
     except NotFoundError:
         pass
```
```diff
@@ -122,62 +192,42 @@ async def search_agents(
         )
 
     if not agents:
-        if source == "marketplace":
-            suggestions = [
+        suggestions = (
+            [
                 "Try more general terms",
                 "Browse categories in the marketplace",
                 "Check spelling",
             ]
-            no_results_msg = (
-                f"No agents found matching '{query}'. Let the user know they can "
-                "try different keywords or browse the marketplace. Also let them "
-                "know you can create a custom agent for them based on their needs."
-            )
-        elif not query:
-            # User asked to list all but library is empty
-            suggestions = [
-                "Browse the marketplace to find and add agents",
-                "Use find_agent to search the marketplace",
-            ]
-            no_results_msg = (
-                "Your library is empty. Let the user know they can browse the "
-                "marketplace to find agents, or you can create a custom agent "
-                "for them based on their needs."
-            )
-        else:
-            suggestions = [
+            if source == "marketplace"
+            else [
                 "Try different keywords",
                 "Use find_agent to search the marketplace",
                 "Check your library at /library",
             ]
-            no_results_msg = (
-                f"No agents matching '{query}' found in your library. Let the "
-                "user know you can create a custom agent for them based on "
-                "their needs."
-            )
+        )
+        no_results_msg = (
+            f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
+            if source == "marketplace"
+            else f"No agents matching '{query}' found in your library."
+        )
         return NoResultsResponse(
             message=no_results_msg, session_id=session_id, suggestions=suggestions
         )
 
-    if source == "marketplace":
-        title = (
-            f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
-        )
-    elif not query:
-        title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
-    else:
-        title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
+    title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
+    title += (
+        f"for '{query}'"
+        if source == "marketplace"
+        else f"in your library for '{query}'"
+    )
 
     message = (
         "Now you have found some options for the user to choose from. "
         "You can add a link to a recommended agent at: /marketplace/agent/agent_id "
-        "Please ask the user if they would like to use any of these agents. "
-        "Let the user know we can create a custom agent for them based on their needs."
+        "Please ask the user if they would like to use any of these agents."
         if source == "marketplace"
-        else "Found agents in the user's library. You can provide a link to view "
-        "an agent at: /library/agents/{agent_id}. Use agent_output to get "
-        "execution results, or run_agent to execute. Let the user know we can "
-        "create a custom agent for them based on their needs."
+        else "Found agents in the user's library. You can provide a link to view an agent at: "
+        "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
    )
 
     return AgentsFoundResponse(
```
```diff
@@ -187,70 +237,3 @@ async def search_agents(
         count=len(agents),
         session_id=session_id,
     )
-
-
-def _is_uuid(text: str) -> bool:
-    """Check if text is a valid UUID v4."""
-    return bool(_UUID_PATTERN.match(text.strip()))
-
-
-def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
-    """Convert a library agent model to an AgentInfo."""
-    return AgentInfo(
-        id=agent.id,
-        name=agent.name,
-        description=agent.description or "",
-        source="library",
-        in_library=True,
-        creator=agent.creator_name,
-        status=agent.status.value,
-        can_access_graph=agent.can_access_graph,
-        has_external_trigger=agent.has_external_trigger,
-        new_output=agent.new_output,
-        graph_id=agent.graph_id,
-        graph_version=agent.graph_version,
-        input_schema=agent.input_schema,
-        output_schema=agent.output_schema,
-    )
-
-
-async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
-    """Fetch a library agent by ID (library agent ID or graph_id).
-
-    Tries multiple lookup strategies:
-    1. First by graph_id (AgentGraph primary key)
-    2. Then by library agent ID (LibraryAgent primary key)
-    """
-    lib_db = library_db()
-
-    try:
-        agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
-        if agent:
-            logger.debug(f"Found library agent by graph_id: {agent.name}")
-            return _library_agent_to_info(agent)
-    except NotFoundError:
-        logger.debug(f"Library agent not found by graph_id: {agent_id}")
-    except DatabaseError:
-        raise
-    except Exception as e:
-        logger.warning(
-            f"Could not fetch library agent by graph_id {agent_id}: {e}",
-            exc_info=True,
-        )
-
-    try:
-        agent = await lib_db.get_library_agent(agent_id, user_id)
-        if agent:
-            logger.debug(f"Found library agent by library_id: {agent.name}")
-            return _library_agent_to_info(agent)
-    except NotFoundError:
-        logger.debug(f"Library agent not found by library_id: {agent_id}")
-    except DatabaseError:
-        raise
-    except Exception as e:
-        logger.warning(
-            f"Could not fetch library agent by library_id {agent_id}: {e}",
-            exc_info=True,
-        )
-
-    return None
```
autogpt_platform/backend/backend/api/features/chat/tools/base.py (new file, 129 lines)
```diff
@@ -0,0 +1,129 @@
+"""Base classes and shared utilities for chat tools."""
+
+import logging
+from typing import Any
+
+from openai.types.chat import ChatCompletionToolParam
+
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.response_model import StreamToolOutputAvailable
+
+from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTool:
+    """Base class for all chat tools."""
+
+    @property
+    def name(self) -> str:
+        """Tool name for OpenAI function calling."""
+        raise NotImplementedError
+
+    @property
+    def description(self) -> str:
+        """Tool description for OpenAI."""
+        raise NotImplementedError
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        """Tool parameters schema for OpenAI."""
+        raise NotImplementedError
+
+    @property
+    def requires_auth(self) -> bool:
+        """Whether this tool requires authentication."""
+        return False
+
+    @property
+    def is_long_running(self) -> bool:
+        """Whether this tool is long-running and should execute in background.
+
+        Long-running tools (like agent generation) are executed via background
+        tasks to survive SSE disconnections. The result is persisted to chat
+        history and visible when the user refreshes.
+        """
+        return False
+
+    def as_openai_tool(self) -> ChatCompletionToolParam:
+        """Convert to OpenAI tool format."""
+        return ChatCompletionToolParam(
+            type="function",
+            function={
+                "name": self.name,
+                "description": self.description,
+                "parameters": self.parameters,
+            },
+        )
+
+    async def execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        tool_call_id: str,
+        **kwargs,
+    ) -> StreamToolOutputAvailable:
+        """Execute the tool with authentication check.
+
+        Args:
+            user_id: User ID (may be anonymous like "anon_123")
+            session_id: Chat session ID
+            **kwargs: Tool-specific parameters
+
+        Returns:
+            Pydantic response object
+
+        """
+        if self.requires_auth and not user_id:
+            logger.error(
+                f"Attempted tool call for {self.name} but user not authenticated"
+            )
+            return StreamToolOutputAvailable(
+                toolCallId=tool_call_id,
+                toolName=self.name,
+                output=NeedLoginResponse(
+                    message=f"Please sign in to use {self.name}",
+                    session_id=session.session_id,
+                ).model_dump_json(),
+                success=False,
+            )
+
+        try:
+            result = await self._execute(user_id, session, **kwargs)
+            return StreamToolOutputAvailable(
+                toolCallId=tool_call_id,
+                toolName=self.name,
+                output=result.model_dump_json(),
+            )
+        except Exception as e:
+            logger.error(f"Error in {self.name}: {e}", exc_info=True)
+            return StreamToolOutputAvailable(
+                toolCallId=tool_call_id,
+                toolName=self.name,
+                output=ErrorResponse(
+                    message=f"An error occurred while executing {self.name}",
+                    error=str(e),
+                    session_id=session.session_id,
+                ).model_dump_json(),
+                success=False,
+            )
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        """Internal execution logic to be implemented by subclasses.
+
+        Args:
+            user_id: User ID (authenticated or anonymous)
+            session_id: Chat session ID
+            **kwargs: Tool-specific parameters
+
+        Returns:
+            Pydantic response object
+
+        """
+        raise NotImplementedError
```
```diff
@@ -0,0 +1,312 @@
+"""CreateAgentTool - Creates agents from natural language descriptions."""
+
+import logging
+from typing import Any
+
+from backend.api.features.chat.model import ChatSession
+
+from .agent_generator import (
+    AgentGeneratorNotConfiguredError,
+    decompose_goal,
+    enrich_library_agents_from_steps,
+    generate_agent,
+    get_all_relevant_agents_for_generation,
+    get_user_message_for_error,
+    save_agent_to_library,
+)
+from .base import BaseTool
+from .models import (
+    AgentPreviewResponse,
+    AgentSavedResponse,
+    ClarificationNeededResponse,
+    ClarifyingQuestion,
+    ErrorResponse,
+    ToolResponseBase,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CreateAgentTool(BaseTool):
+    """Tool for creating agents from natural language descriptions."""
+
+    @property
+    def name(self) -> str:
+        return "create_agent"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Create a new agent workflow from a natural language description. "
+            "First generates a preview, then saves to library if save=true."
+        )
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    @property
+    def is_long_running(self) -> bool:
+        return True
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "description": {
+                    "type": "string",
+                    "description": (
+                        "Natural language description of what the agent should do. "
+                        "Be specific about inputs, outputs, and the workflow steps."
+                    ),
+                },
+                "context": {
+                    "type": "string",
+                    "description": (
+                        "Additional context or answers to previous clarifying questions. "
+                        "Include any preferences or constraints mentioned by the user."
+                    ),
+                },
+                "save": {
+                    "type": "boolean",
+                    "description": (
+                        "Whether to save the agent to the user's library. "
+                        "Default is true. Set to false for preview only."
+                    ),
+                    "default": True,
+                },
+            },
+            "required": ["description"],
+        }
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        """Execute the create_agent tool.
+
+        Flow:
+        1. Decompose the description into steps (may return clarifying questions)
+        2. Generate agent JSON (external service handles fixing and validation)
+        3. Preview or save based on the save parameter
+        """
+        description = kwargs.get("description", "").strip()
+        context = kwargs.get("context", "")
+        save = kwargs.get("save", True)
+        session_id = session.session_id if session else None
+
+        if not description:
+            return ErrorResponse(
+                message="Please provide a description of what the agent should do.",
+                error="Missing description parameter",
+                session_id=session_id,
+            )
+
+        library_agents = None
+        if user_id:
+            try:
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=description,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch library agents: {e}")
+
+        try:
+            decomposition_result = await decompose_goal(
+                description, context, library_agents
+            )
+        except AgentGeneratorNotConfiguredError:
+            return ErrorResponse(
+                message=(
+                    "Agent generation is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
+                session_id=session_id,
+            )
+
+        if decomposition_result is None:
+            return ErrorResponse(
+                message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
+                error="decomposition_failed",
+                details={"description": description[:100]},
+                session_id=session_id,
+            )
+
+        if decomposition_result.get("type") == "error":
+            error_msg = decomposition_result.get("error", "Unknown error")
+            error_type = decomposition_result.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="analyze the goal",
+                llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"decomposition_failed:{error_type}",
+                details={
+                    "description": description[:100],
+                    "service_error": error_msg,
+                    "error_type": error_type,
+                },
+                session_id=session_id,
+            )
+
+        if decomposition_result.get("type") == "clarifying_questions":
+            questions = decomposition_result.get("questions", [])
+            return ClarificationNeededResponse(
+                message=(
+                    "I need some more information to create this agent. "
+                    "Please answer the following questions:"
+                ),
+                questions=[
+                    ClarifyingQuestion(
+                        question=q.get("question", ""),
+                        keyword=q.get("keyword", ""),
+                        example=q.get("example"),
+                    )
+                    for q in questions
+                ],
+                session_id=session_id,
+            )
+
+        if decomposition_result.get("type") == "unachievable_goal":
+            suggested = decomposition_result.get("suggested_goal", "")
+            reason = decomposition_result.get("reason", "")
+            return ErrorResponse(
+                message=(
+                    f"This goal cannot be accomplished with the available blocks. "
+                    f"{reason} "
+                    f"Suggestion: {suggested}"
+                ),
+                error="unachievable_goal",
+                details={"suggested_goal": suggested, "reason": reason},
+                session_id=session_id,
+            )
+
+        if decomposition_result.get("type") == "vague_goal":
+            suggested = decomposition_result.get("suggested_goal", "")
+            return ErrorResponse(
+                message=(
+                    f"The goal is too vague to create a specific workflow. "
+                    f"Suggestion: {suggested}"
+                ),
+                error="vague_goal",
+                details={"suggested_goal": suggested},
+                session_id=session_id,
+            )
+
+        if user_id and library_agents is not None:
+            try:
+                library_agents = await enrich_library_agents_from_steps(
+                    user_id=user_id,
+                    decomposition_result=decomposition_result,
+                    existing_agents=library_agents,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to enrich library agents from steps: {e}")
+
+        try:
+            agent_json = await generate_agent(decomposition_result, library_agents)
+        except AgentGeneratorNotConfiguredError:
+            return ErrorResponse(
+                message=(
+                    "Agent generation is not available. "
+                    "The Agent Generator service is not configured."
+                ),
+                error="service_not_configured",
+                session_id=session_id,
+            )
+
+        if agent_json is None:
+            return ErrorResponse(
+                message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
+                error="generation_failed",
+                details={"description": description[:100]},
+                session_id=session_id,
+            )
+
+        if isinstance(agent_json, dict) and agent_json.get("type") == "error":
+            error_msg = agent_json.get("error", "Unknown error")
+            error_type = agent_json.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="generate the agent",
+                llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
+                validation_message=(
+                    "I wasn't able to create a valid agent for this request. "
+                    "The generated workflow had some structural issues. "
+                    "Please try simplifying your goal or breaking it into smaller steps."
+                ),
+                error_details=error_msg,
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"generation_failed:{error_type}",
+                details={
+                    "description": description[:100],
+                    "service_error": error_msg,
+                    "error_type": error_type,
+                },
+                session_id=session_id,
+            )
+
+        agent_name = agent_json.get("name", "Generated Agent")
+        agent_description = agent_json.get("description", "")
+        node_count = len(agent_json.get("nodes", []))
+        link_count = len(agent_json.get("links", []))
+
+        if not save:
+            return AgentPreviewResponse(
+                message=(
+                    f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
+                    f"Review it and call create_agent with save=true to save it to your library."
+                ),
+                agent_json=agent_json,
+                agent_name=agent_name,
+                description=agent_description,
+                node_count=node_count,
+                link_count=link_count,
+                session_id=session_id,
+            )
+
+        if not user_id:
+            return ErrorResponse(
+                message="You must be logged in to save agents.",
+                error="auth_required",
+                session_id=session_id,
+            )
+
+        try:
+            created_graph, library_agent = await save_agent_to_library(
+                agent_json, user_id
+            )
+
+            return AgentSavedResponse(
+                message=f"Agent '{created_graph.name}' has been saved to your library!",
+                agent_id=created_graph.id,
+                agent_name=created_graph.name,
+                library_agent_id=library_agent.id,
+                library_agent_link=f"/library/agents/{library_agent.id}",
+                agent_page_link=f"/build?flowID={created_graph.id}",
+                session_id=session_id,
+            )
+        except Exception as e:
+            return ErrorResponse(
+                message=f"Failed to save the agent: {str(e)}",
+                error="save_failed",
+                details={"exception": str(e)},
+                session_id=session_id,
+            )
```
```diff
@@ -0,0 +1,261 @@
+"""EditAgentTool - Edits existing agents using natural language."""
+
+import logging
+from typing import Any
+
+from backend.api.features.chat.model import ChatSession
+
+from .agent_generator import (
+    AgentGeneratorNotConfiguredError,
+    generate_agent_patch,
+    get_agent_as_json,
+    get_all_relevant_agents_for_generation,
+    get_user_message_for_error,
+    save_agent_to_library,
+)
+from .base import BaseTool
+from .models import (
+    AgentPreviewResponse,
+    AgentSavedResponse,
+    ClarificationNeededResponse,
+    ClarifyingQuestion,
+    ErrorResponse,
+    ToolResponseBase,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class EditAgentTool(BaseTool):
+    """Tool for editing existing agents using natural language."""
+
+    @property
+    def name(self) -> str:
+        return "edit_agent"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Edit an existing agent from the user's library using natural language. "
+            "Generates updates to the agent while preserving unchanged parts."
+        )
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    @property
+    def is_long_running(self) -> bool:
+        return True
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "agent_id": {
+                    "type": "string",
+                    "description": (
+                        "The ID of the agent to edit. "
+                        "Can be a graph ID or library agent ID."
+                    ),
+                },
+                "changes": {
+                    "type": "string",
+                    "description": (
+                        "Natural language description of what changes to make. "
+                        "Be specific about what to add, remove, or modify."
+                    ),
+                },
+                "context": {
+                    "type": "string",
+                    "description": (
+                        "Additional context or answers to previous clarifying questions."
+                    ),
+                },
+                "save": {
+                    "type": "boolean",
+                    "description": (
+                        "Whether to save the changes. "
+                        "Default is true. Set to false for preview only."
+                    ),
+                    "default": True,
+                },
+            },
+            "required": ["agent_id", "changes"],
+        }
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        """Execute the edit_agent tool.
+
+        Flow:
+        1. Fetch the current agent
+        2. Generate updated agent (external service handles fixing and validation)
+        3. Preview or save based on the save parameter
+        """
+        agent_id = kwargs.get("agent_id", "").strip()
+        changes = kwargs.get("changes", "").strip()
+        context = kwargs.get("context", "")
+        save = kwargs.get("save", True)
+        session_id = session.session_id if session else None
+
+        if not agent_id:
+            return ErrorResponse(
+                message="Please provide the agent ID to edit.",
+                error="Missing agent_id parameter",
+                session_id=session_id,
+            )
+
+        if not changes:
+            return ErrorResponse(
+                message="Please describe what changes you want to make.",
+                error="Missing changes parameter",
+                session_id=session_id,
+            )
+
+        current_agent = await get_agent_as_json(agent_id, user_id)
+
+        if current_agent is None:
+            return ErrorResponse(
+                message=f"Could not find agent with ID '{agent_id}' in your library.",
+                error="agent_not_found",
+                session_id=session_id,
+            )
+
+        library_agents = None
+        if user_id:
+            try:
+                graph_id = current_agent.get("id")
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=changes,
+                    exclude_graph_id=graph_id,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch library agents: {e}")
+
+        update_request = changes
+        if context:
+            update_request = f"{changes}\n\nAdditional context:\n{context}"
+
+        try:
+            result = await generate_agent_patch(
+                update_request, current_agent, library_agents
+            )
+        except AgentGeneratorNotConfiguredError:
```
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Agent editing is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||||
|
error="update_generation_failed",
|
||||||
|
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="generate the changes",
|
||||||
|
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||||
|
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"update_generation_failed:{error_type}",
|
||||||
|
details={
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"changes": changes[:100],
|
||||||
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("type") == "clarifying_questions":
|
||||||
|
questions = result.get("questions", [])
|
||||||
|
return ClarificationNeededResponse(
|
||||||
|
message=(
|
||||||
|
"I need some more information about the changes. "
|
||||||
|
"Please answer the following questions:"
|
||||||
|
),
|
||||||
|
questions=[
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", ""),
|
||||||
|
keyword=q.get("keyword", ""),
|
||||||
|
example=q.get("example"),
|
||||||
|
)
|
||||||
|
for q in questions
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_agent = result
|
||||||
|
|
||||||
|
agent_name = updated_agent.get("name", "Updated Agent")
|
||||||
|
agent_description = updated_agent.get("description", "")
|
||||||
|
node_count = len(updated_agent.get("nodes", []))
|
||||||
|
link_count = len(updated_agent.get("links", []))
|
||||||
|
|
||||||
|
if not save:
|
||||||
|
return AgentPreviewResponse(
|
||||||
|
message=(
|
||||||
|
f"I've updated the agent. "
|
||||||
|
f"The agent now has {node_count} blocks. "
|
||||||
|
f"Review it and call edit_agent with save=true to save the changes."
|
||||||
|
),
|
||||||
|
agent_json=updated_agent,
|
||||||
|
agent_name=agent_name,
|
||||||
|
description=agent_description,
|
||||||
|
node_count=node_count,
|
||||||
|
link_count=link_count,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="You must be logged in to save agents.",
|
||||||
|
error="auth_required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
created_graph, library_agent = await save_agent_to_library(
|
||||||
|
updated_agent, user_id, is_update=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentSavedResponse(
|
||||||
|
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||||
|
agent_id=created_graph.id,
|
||||||
|
agent_name=created_graph.name,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to save the updated agent: {str(e)}",
|
||||||
|
error="save_failed",
|
||||||
|
details={"exception": str(e)},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
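The argument checks at the top of `_execute` (reject a missing or blank `agent_id`/`changes` before calling the generator) follow the tool's own `parameters` JSON schema. A minimal standalone sketch of that pattern, with a hypothetical helper name that is not part of AutoGPT's API:

```python
# Hypothetical sketch: pre-flight validation of tool kwargs against a
# JSON-schema-style "parameters" dict, mirroring edit_agent's checks.

def validate_tool_args(parameters: dict, kwargs: dict) -> list[str]:
    """Return one error string per required field that is missing or blank."""
    errors = []
    for field in parameters.get("required", []):
        value = kwargs.get(field, "")
        # String fields are stripped, mirroring kwargs.get(...).strip()
        if isinstance(value, str):
            value = value.strip()
        if not value:
            errors.append(f"Missing {field} parameter")
    return errors


schema = {
    "type": "object",
    "properties": {"agent_id": {"type": "string"}, "changes": {"type": "string"}},
    "required": ["agent_id", "changes"],
}

print(validate_tool_args(schema, {"agent_id": "g-123", "changes": "  "}))
# → ['Missing changes parameter']
```

Running the checks before any network call keeps the error responses cheap and deterministic, as the tool does above.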
@@ -2,7 +2,7 @@

 from typing import Any

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession

 from .agent_search import search_agents
 from .base import BaseTool
@@ -0,0 +1,193 @@
import logging
from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
    BlockInfoSummary,
    BlockInputFieldInfo,
    BlockListResponse,
    ErrorResponse,
    NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block

logger = logging.getLogger(__name__)


class FindBlockTool(BaseTool):
    """Tool for searching available blocks."""

    @property
    def name(self) -> str:
        return "find_block"

    @property
    def description(self) -> str:
        return (
            "Search for available blocks by name or description. "
            "Blocks are reusable components that perform specific tasks like "
            "sending emails, making API calls, processing text, etc. "
            "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
            "The response includes each block's id, required_inputs, and input_schema."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "Search query to find blocks by name or description. "
                        "Use keywords like 'email', 'http', 'text', 'ai', etc."
                    ),
                },
            },
            "required": ["query"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Search for blocks matching the query.

        Args:
            user_id: User ID (required)
            session: Chat session
            query: Search query

        Returns:
            BlockListResponse: List of matching blocks
            NoResultsResponse: No blocks found
            ErrorResponse: Error message
        """
        query = kwargs.get("query", "").strip()
        session_id = session.session_id

        if not query:
            return ErrorResponse(
                message="Please provide a search query",
                session_id=session_id,
            )

        try:
            # Search for blocks using hybrid search
            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
                page_size=10,
            )

            if not results:
                return NoResultsResponse(
                    message=f"No blocks found for '{query}'",
                    suggestions=[
                        "Try broader keywords like 'email', 'http', 'text', 'ai'",
                        "Check spelling of technical terms",
                    ],
                    session_id=session_id,
                )

            # Enrich results with full block information
            blocks: list[BlockInfoSummary] = []
            for result in results:
                block_id = result["content_id"]
                block = get_block(block_id)

                # Skip disabled blocks
                if block and not block.disabled:
                    # Get input/output schemas
                    input_schema = {}
                    output_schema = {}
                    try:
                        input_schema = block.input_schema.jsonschema()
                    except Exception:
                        pass
                    try:
                        output_schema = block.output_schema.jsonschema()
                    except Exception:
                        pass

                    # Get categories from block instance
                    categories = []
                    if hasattr(block, "categories") and block.categories:
                        categories = [cat.value for cat in block.categories]

                    # Extract required inputs for easier use
                    required_inputs: list[BlockInputFieldInfo] = []
                    if input_schema:
                        properties = input_schema.get("properties", {})
                        required_fields = set(input_schema.get("required", []))
                        # Get credential field names to exclude from required inputs
                        credentials_fields = set(
                            block.input_schema.get_credentials_fields().keys()
                        )

                        for field_name, field_schema in properties.items():
                            # Skip credential fields - they're handled separately
                            if field_name in credentials_fields:
                                continue

                            required_inputs.append(
                                BlockInputFieldInfo(
                                    name=field_name,
                                    type=field_schema.get("type", "string"),
                                    description=field_schema.get("description", ""),
                                    required=field_name in required_fields,
                                    default=field_schema.get("default"),
                                )
                            )

                    blocks.append(
                        BlockInfoSummary(
                            id=block_id,
                            name=block.name,
                            description=block.description or "",
                            categories=categories,
                            input_schema=input_schema,
                            output_schema=output_schema,
                            required_inputs=required_inputs,
                        )
                    )

            if not blocks:
                return NoResultsResponse(
                    message=f"No blocks found for '{query}'",
                    suggestions=[
                        "Try broader keywords like 'email', 'http', 'text', 'ai'",
                    ],
                    session_id=session_id,
                )

            return BlockListResponse(
                message=(
                    f"Found {len(blocks)} block(s) matching '{query}'. "
                    "To execute a block, use run_block with the block's 'id' field "
                    "and provide 'input_data' matching the block's input_schema."
                ),
                blocks=blocks,
                count=len(blocks),
                query=query,
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error searching blocks: {e}", exc_info=True)
            return ErrorResponse(
                message="Failed to search blocks",
                error=str(e),
                session_id=session_id,
            )
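The required-input extraction loop above can be shown in isolation: given a block's JSON input schema, keep the non-credential fields and record whether each is required. The helper and the sample schema below are illustrative only, not AutoGPT's real schemas:

```python
# Standalone sketch of find_block's required-input extraction.

def extract_input_fields(input_schema: dict, credential_fields: set[str]) -> list[dict]:
    """List non-credential input fields with type/required/default info."""
    properties = input_schema.get("properties", {})
    required = set(input_schema.get("required", []))
    fields = []
    for name, spec in properties.items():
        if name in credential_fields:
            continue  # credentials are handled separately, as in FindBlockTool
        fields.append(
            {
                "name": name,
                "type": spec.get("type", "string"),
                "required": name in required,
                "default": spec.get("default"),
            }
        )
    return fields


schema = {
    "properties": {
        "to": {"type": "string"},
        "subject": {"type": "string", "default": ""},
        "credentials": {"type": "object"},
    },
    "required": ["to"],
}

print(extract_input_fields(schema, {"credentials"}))
```

Filtering credential fields out here is what lets the chat client ask the user only for real inputs while credentials flow through the separate credential system.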
@@ -2,7 +2,7 @@

 from typing import Any

-from backend.copilot.model import ChatSession
+from backend.api.features.chat.model import ChatSession

 from .agent_search import search_agents
 from .base import BaseTool
@@ -19,13 +19,9 @@ class FindLibraryAgentTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Search for or list agents in the user's library. Use this to find "
-            "agents the user has already added to their library, including agents "
-            "they created or added from the marketplace. "
-            "When creating agents with sub-agent composition, use this to get "
-            "the agent's graph_id, graph_version, input_schema, and output_schema "
-            "needed for AgentExecutorBlock nodes. "
-            "Omit the query to list all agents."
+            "Search for agents in the user's library. Use this to find agents "
+            "the user has already added to their library, including agents they "
+            "created or added from the marketplace."
         )

     @property
@@ -35,13 +31,10 @@ class FindLibraryAgentTool(BaseTool):
             "properties": {
                 "query": {
                     "type": "string",
-                    "description": (
-                        "Search query to find agents by name or description. "
-                        "Omit to list all agents in the library."
-                    ),
+                    "description": "Search query to find agents by name or description.",
                 },
             },
-            "required": [],
+            "required": ["query"],
         }

     @property
@@ -52,7 +45,7 @@ class FindLibraryAgentTool(BaseTool):
         self, user_id: str | None, session: ChatSession, **kwargs
     ) -> ToolResponseBase:
         return await search_agents(
-            query=(kwargs.get("query") or "").strip(),
+            query=kwargs.get("query", "").strip(),
             source="library",
             session_id=session.session_id,
             user_id=user_id,
@@ -4,10 +4,13 @@ import logging
 from pathlib import Path
 from typing import Any

-from backend.copilot.model import ChatSession
-from .base import BaseTool
-from .models import DocPageResponse, ErrorResponse, ToolResponseBase
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.base import BaseTool
+from backend.api.features.chat.tools.models import (
+    DocPageResponse,
+    ErrorResponse,
+    ToolResponseBase,
+)

 logger = logging.getLogger(__name__)
@@ -0,0 +1,382 @@
"""Pydantic models for tool responses."""

from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field

from backend.data.model import CredentialsMetaInput


class ResponseType(str, Enum):
    """Types of tool responses."""

    AGENTS_FOUND = "agents_found"
    AGENT_DETAILS = "agent_details"
    SETUP_REQUIREMENTS = "setup_requirements"
    EXECUTION_STARTED = "execution_started"
    NEED_LOGIN = "need_login"
    ERROR = "error"
    NO_RESULTS = "no_results"
    AGENT_OUTPUT = "agent_output"
    UNDERSTANDING_UPDATED = "understanding_updated"
    AGENT_PREVIEW = "agent_preview"
    AGENT_SAVED = "agent_saved"
    CLARIFICATION_NEEDED = "clarification_needed"
    BLOCK_LIST = "block_list"
    BLOCK_OUTPUT = "block_output"
    DOC_SEARCH_RESULTS = "doc_search_results"
    DOC_PAGE = "doc_page"
    # Workspace response types
    WORKSPACE_FILE_LIST = "workspace_file_list"
    WORKSPACE_FILE_CONTENT = "workspace_file_content"
    WORKSPACE_FILE_METADATA = "workspace_file_metadata"
    WORKSPACE_FILE_WRITTEN = "workspace_file_written"
    WORKSPACE_FILE_DELETED = "workspace_file_deleted"
    # Long-running operation types
    OPERATION_STARTED = "operation_started"
    OPERATION_PENDING = "operation_pending"
    OPERATION_IN_PROGRESS = "operation_in_progress"


# Base response model
class ToolResponseBase(BaseModel):
    """Base model for all tool responses."""

    type: ResponseType
    message: str
    session_id: str | None = None


# Agent discovery models
class AgentInfo(BaseModel):
    """Information about an agent."""

    id: str
    name: str
    description: str
    source: str = Field(description="marketplace or library")
    in_library: bool = False
    creator: str | None = None
    category: str | None = None
    rating: float | None = None
    runs: int | None = None
    is_featured: bool | None = None
    status: str | None = None
    can_access_graph: bool | None = None
    has_external_trigger: bool | None = None
    new_output: bool | None = None
    graph_id: str | None = None


class AgentsFoundResponse(ToolResponseBase):
    """Response for find_agent tool."""

    type: ResponseType = ResponseType.AGENTS_FOUND
    title: str = "Available Agents"
    agents: list[AgentInfo]
    count: int
    name: str = "agents_found"


class NoResultsResponse(ToolResponseBase):
    """Response when no agents found."""

    type: ResponseType = ResponseType.NO_RESULTS
    suggestions: list[str] = []
    name: str = "no_results"


# Agent details models
class InputField(BaseModel):
    """Input field specification."""

    name: str
    type: str = "string"
    description: str = ""
    required: bool = False
    default: Any | None = None
    options: list[Any] | None = None
    format: str | None = None


class ExecutionOptions(BaseModel):
    """Available execution options for an agent."""

    manual: bool = True
    scheduled: bool = True
    webhook: bool = False


class AgentDetails(BaseModel):
    """Detailed agent information."""

    id: str
    name: str
    description: str
    in_library: bool = False
    inputs: dict[str, Any] = {}
    credentials: list[CredentialsMetaInput] = []
    execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
    trigger_info: dict[str, Any] | None = None


class AgentDetailsResponse(ToolResponseBase):
    """Response for get_details action."""

    type: ResponseType = ResponseType.AGENT_DETAILS
    agent: AgentDetails
    user_authenticated: bool = False
    graph_id: str | None = None
    graph_version: int | None = None


# Setup info models
class UserReadiness(BaseModel):
    """User readiness status."""

    has_all_credentials: bool = False
    missing_credentials: dict[str, Any] = {}
    ready_to_run: bool = False


class SetupInfo(BaseModel):
    """Complete setup information."""

    agent_id: str
    agent_name: str
    requirements: dict[str, list[Any]] = Field(
        default_factory=lambda: {
            "credentials": [],
            "inputs": [],
            "execution_modes": [],
        },
    )
    user_readiness: UserReadiness = Field(default_factory=UserReadiness)


class SetupRequirementsResponse(ToolResponseBase):
    """Response for validate action."""

    type: ResponseType = ResponseType.SETUP_REQUIREMENTS
    setup_info: SetupInfo
    graph_id: str | None = None
    graph_version: int | None = None


# Execution models
class ExecutionStartedResponse(ToolResponseBase):
    """Response for run/schedule actions."""

    type: ResponseType = ResponseType.EXECUTION_STARTED
    execution_id: str
    graph_id: str
    graph_name: str
    library_agent_id: str | None = None
    library_agent_link: str | None = None
    status: str = "QUEUED"


# Auth/error models
class NeedLoginResponse(ToolResponseBase):
    """Response when login is needed."""

    type: ResponseType = ResponseType.NEED_LOGIN
    agent_info: dict[str, Any] | None = None


class ErrorResponse(ToolResponseBase):
    """Response for errors."""

    type: ResponseType = ResponseType.ERROR
    error: str | None = None
    details: dict[str, Any] | None = None


# Agent output models
class ExecutionOutputInfo(BaseModel):
    """Summary of a single execution's outputs."""

    execution_id: str
    status: str
    started_at: datetime | None = None
    ended_at: datetime | None = None
    outputs: dict[str, list[Any]]
    inputs_summary: dict[str, Any] | None = None


class AgentOutputResponse(ToolResponseBase):
    """Response for agent_output tool."""

    type: ResponseType = ResponseType.AGENT_OUTPUT
    agent_name: str
    agent_id: str
    library_agent_id: str | None = None
    library_agent_link: str | None = None
    execution: ExecutionOutputInfo | None = None
    available_executions: list[dict[str, Any]] | None = None
    total_executions: int = 0


# Business understanding models
class UnderstandingUpdatedResponse(ToolResponseBase):
    """Response for add_understanding tool."""

    type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
    updated_fields: list[str] = Field(default_factory=list)
    current_understanding: dict[str, Any] = Field(default_factory=dict)


# Agent generation models
class ClarifyingQuestion(BaseModel):
    """A question that needs user clarification."""

    question: str
    keyword: str
    example: str | None = None


class AgentPreviewResponse(ToolResponseBase):
    """Response for previewing a generated agent before saving."""

    type: ResponseType = ResponseType.AGENT_PREVIEW
    agent_json: dict[str, Any]
    agent_name: str
    description: str
    node_count: int
    link_count: int = 0


class AgentSavedResponse(ToolResponseBase):
    """Response when an agent is saved to the library."""

    type: ResponseType = ResponseType.AGENT_SAVED
    agent_id: str
    agent_name: str
    library_agent_id: str
    library_agent_link: str
    agent_page_link: str  # Link to the agent builder/editor page


class ClarificationNeededResponse(ToolResponseBase):
    """Response when the LLM needs more information from the user."""

    type: ResponseType = ResponseType.CLARIFICATION_NEEDED
    questions: list[ClarifyingQuestion] = Field(default_factory=list)


# Documentation search models
class DocSearchResult(BaseModel):
    """A single documentation search result."""

    title: str
    path: str
    section: str
    snippet: str  # Short excerpt for UI display
    score: float
    doc_url: str | None = None


class DocSearchResultsResponse(ToolResponseBase):
    """Response for search_docs tool."""

    type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
    results: list[DocSearchResult]
    count: int
    query: str


class DocPageResponse(ToolResponseBase):
    """Response for get_doc_page tool."""

    type: ResponseType = ResponseType.DOC_PAGE
    title: str
    path: str
    content: str  # Full document content
    doc_url: str | None = None


# Block models
class BlockInputFieldInfo(BaseModel):
    """Information about a block input field."""

    name: str
    type: str
    description: str = ""
    required: bool = False
    default: Any | None = None


class BlockInfoSummary(BaseModel):
    """Summary of a block for search results."""

    id: str
    name: str
    description: str
    categories: list[str]
    input_schema: dict[str, Any]
    output_schema: dict[str, Any]
    required_inputs: list[BlockInputFieldInfo] = Field(
        default_factory=list,
        description="List of required input fields for this block",
    )


class BlockListResponse(ToolResponseBase):
    """Response for find_block tool."""

    type: ResponseType = ResponseType.BLOCK_LIST
    blocks: list[BlockInfoSummary]
    count: int
    query: str
    usage_hint: str = Field(
        default="To execute a block, call run_block with block_id set to the block's "
        "'id' field and input_data containing the required fields from input_schema."
    )


class BlockOutputResponse(ToolResponseBase):
    """Response for run_block tool."""

    type: ResponseType = ResponseType.BLOCK_OUTPUT
    block_id: str
    block_name: str
    outputs: dict[str, list[Any]]
    success: bool = True


# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
    """Response when a long-running operation has been started in the background.

    This is returned immediately to the client while the operation continues
    to execute. The user can close the tab and check back later.
    """

    type: ResponseType = ResponseType.OPERATION_STARTED
    operation_id: str
    tool_name: str


class OperationPendingResponse(ToolResponseBase):
    """Response stored in chat history while a long-running operation is executing.

    This is persisted to the database so users see a pending state when they
    refresh before the operation completes.
    """

    type: ResponseType = ResponseType.OPERATION_PENDING
    operation_id: str
    tool_name: str


class OperationInProgressResponse(ToolResponseBase):
    """Response when an operation is already in progress.

    Returned for idempotency when the same tool_call_id is requested again
    while the background task is still running.
    """

    type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
    tool_call_id: str
@@ -5,13 +5,16 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from backend.copilot.config import ChatConfig
|
from backend.api.features.chat.config import ChatConfig
|
||||||
from backend.copilot.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
from backend.api.features.chat.tracking import (
|
||||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
track_agent_run_success,
|
||||||
from backend.data.execution import ExecutionStatus
|
track_agent_scheduled,
|
||||||
|
)
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.user import get_user_by_id
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
@@ -21,17 +24,12 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
-from .execution_utils import get_execution_outputs, wait_for_execution
-from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
-    AgentOutputResponse,
     ErrorResponse,
     ExecutionOptions,
-    ExecutionOutputInfo,
     ExecutionStartedResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -70,7 +68,6 @@ class RunAgentInput(BaseModel):
     schedule_name: str = ""
     cron: str = ""
     timezone: str = "UTC"
-    wait_for_result: int = Field(default=0, ge=0, le=300)
 
     @field_validator(
         "username_agent_slug",
@@ -152,14 +149,6 @@ class RunAgentTool(BaseTool):
                     "type": "string",
                     "description": "IANA timezone for schedule (default: UTC)",
                 },
-                "wait_for_result": {
-                    "type": "integer",
-                    "description": (
-                        "Max seconds to wait for execution to complete (0-300). "
-                        "If >0, blocks until the execution finishes or times out. "
-                        "Returns execution outputs when complete."
-                    ),
-                },
             },
             "required": [],
         }
@@ -209,7 +198,7 @@ class RunAgentTool(BaseTool):
 
         # Priority: library_agent_id if provided
         if has_library_id:
-            library_agent = await library_db().get_library_agent(
+            library_agent = await library_db.get_library_agent(
                 params.library_agent_id, user_id
             )
            if not library_agent:
@@ -218,7 +207,9 @@ class RunAgentTool(BaseTool):
                    session_id=session_id,
                )
            # Get the graph from the library agent
-            graph = await graph_db().get_graph(
+            from backend.data.graph import get_graph
+
+            graph = await get_graph(
                library_agent.graph_id,
                library_agent.graph_version,
                user_id=user_id,
@@ -269,7 +260,7 @@ class RunAgentTool(BaseTool):
                    ),
                    requirements={
                        "credentials": requirements_creds_list,
-                        "inputs": get_inputs_from_schema(graph.input_schema),
+                        "inputs": self._get_inputs_list(graph.input_schema),
                        "execution_modes": self._get_execution_modes(graph),
                    },
                ),
@@ -282,22 +273,6 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
-        valid_fields = set(input_properties.keys())
-
-        # Check for unknown input fields
-        unrecognized_fields = provided_inputs - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Agent was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=graph.input_schema,
-                graph_id=graph.id,
-                graph_version=graph.version,
-            )
-
         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide
@@ -354,7 +329,6 @@ class RunAgentTool(BaseTool):
                graph=graph,
                graph_credentials=graph_credentials,
                inputs=params.inputs,
-                wait_for_result=params.wait_for_result,
            )
 
        except NotFoundError as e:
@@ -378,6 +352,22 @@ class RunAgentTool(BaseTool):
                session_id=session_id,
            )
 
+    def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
+        """Extract inputs list from schema."""
+        inputs_list = []
+        if isinstance(input_schema, dict) and "properties" in input_schema:
+            for field_name, field_schema in input_schema["properties"].items():
+                inputs_list.append(
+                    {
+                        "name": field_name,
+                        "title": field_schema.get("title", field_name),
+                        "type": field_schema.get("type", "string"),
+                        "description": field_schema.get("description", ""),
+                        "required": field_name in input_schema.get("required", []),
+                    }
+                )
+        return inputs_list
+
     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
         trigger_info = graph.trigger_setup_info
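The `_get_inputs_list` helper reintroduced in the hunk above is plain JSON-schema traversal. A standalone sketch of the same idea (function name and sample schema are illustrative, not from the repo):

```python
from typing import Any


def get_inputs_list(input_schema: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten a JSON-schema 'properties' map into a list of input descriptors."""
    inputs_list: list[dict[str, Any]] = []
    if isinstance(input_schema, dict) and "properties" in input_schema:
        required = set(input_schema.get("required", []))
        for field_name, field_schema in input_schema["properties"].items():
            inputs_list.append(
                {
                    "name": field_name,
                    "title": field_schema.get("title", field_name),
                    "type": field_schema.get("type", "string"),
                    "description": field_schema.get("description", ""),
                    "required": field_name in required,
                }
            )
    return inputs_list


# Hypothetical agent input schema for demonstration only
schema = {
    "properties": {"topic": {"type": "string", "title": "Topic"}},
    "required": ["topic"],
}
print(get_inputs_list(schema))
# → [{'name': 'topic', 'title': 'Topic', 'type': 'string', 'description': '', 'required': True}]
```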
@@ -391,7 +381,7 @@ class RunAgentTool(BaseTool):
        suffix: str,
    ) -> str:
        """Build a message describing available inputs for an agent."""
-        inputs_list = get_inputs_from_schema(graph.input_schema)
+        inputs_list = self._get_inputs_list(graph.input_schema)
        required_names = [i["name"] for i in inputs_list if i["required"]]
        optional_names = [i["name"] for i in inputs_list if not i["required"]]
 
@@ -438,9 +428,8 @@ class RunAgentTool(BaseTool):
        graph: GraphModel,
        graph_credentials: dict[str, CredentialsMetaInput],
        inputs: dict[str, Any],
-        wait_for_result: int = 0,
    ) -> ToolResponseBase:
-        """Execute an agent immediately, optionally waiting for completion."""
+        """Execute an agent immediately."""
        session_id = session.session_id
 
        # Check rate limits
@@ -477,93 +466,6 @@ class RunAgentTool(BaseTool):
        )
 
        library_agent_link = f"/library/agents/{library_agent.id}"
 
-        # If wait_for_result is requested, wait for execution to complete
-        if wait_for_result > 0:
-            logger.info(
-                f"Waiting up to {wait_for_result}s for execution {execution.id}"
-            )
-            completed = await wait_for_execution(
-                user_id=user_id,
-                graph_id=library_agent.graph_id,
-                execution_id=execution.id,
-                timeout_seconds=wait_for_result,
-            )
-
-            if completed and completed.status == ExecutionStatus.COMPLETED:
-                outputs = get_execution_outputs(completed)
-                return AgentOutputResponse(
-                    message=(
-                        f"Agent '{library_agent.name}' completed successfully. "
-                        f"View at {library_agent_link}."
-                    ),
-                    session_id=session_id,
-                    agent_name=library_agent.name,
-                    agent_id=library_agent.graph_id,
-                    library_agent_id=library_agent.id,
-                    library_agent_link=library_agent_link,
-                    execution=ExecutionOutputInfo(
-                        execution_id=execution.id,
-                        status=completed.status.value,
-                        started_at=completed.started_at,
-                        ended_at=completed.ended_at,
-                        outputs=outputs or {},
-                    ),
-                )
-            elif completed and completed.status == ExecutionStatus.FAILED:
-                error_detail = completed.stats.error if completed.stats else None
-                return ErrorResponse(
-                    message=(
-                        f"Agent '{library_agent.name}' execution failed. "
-                        f"View details at {library_agent_link}."
-                    ),
-                    session_id=session_id,
-                    error=error_detail,
-                )
-            elif completed and completed.status == ExecutionStatus.TERMINATED:
-                error_detail = completed.stats.error if completed.stats else None
-                return ErrorResponse(
-                    message=(
-                        f"Agent '{library_agent.name}' execution was terminated. "
-                        f"View details at {library_agent_link}."
-                    ),
-                    session_id=session_id,
-                    error=error_detail,
-                )
-            elif completed and completed.status == ExecutionStatus.REVIEW:
-                return ExecutionStartedResponse(
-                    message=(
-                        f"Agent '{library_agent.name}' is awaiting human review. "
-                        f"The user can approve or reject inline. After approval, "
-                        f"the execution resumes automatically. Use view_agent_output "
-                        f"with execution_id='{execution.id}' to check the result."
-                    ),
-                    session_id=session_id,
-                    execution_id=execution.id,
-                    graph_id=library_agent.graph_id,
-                    graph_name=library_agent.name,
-                    library_agent_id=library_agent.id,
-                    library_agent_link=library_agent_link,
-                    status=ExecutionStatus.REVIEW.value,
-                )
-            else:
-                status = completed.status.value if completed else "unknown"
-                return ExecutionStartedResponse(
-                    message=(
-                        f"Agent '{library_agent.name}' is still {status} after "
-                        f"{wait_for_result}s. Check results later at "
-                        f"{library_agent_link}. "
-                        f"Use view_agent_output with wait_if_running to check again."
-                    ),
-                    session_id=session_id,
-                    execution_id=execution.id,
-                    graph_id=library_agent.graph_id,
-                    graph_name=library_agent.name,
-                    library_agent_id=library_agent.id,
-                    library_agent_link=library_agent_link,
-                    status=status,
-                )
-
        return ExecutionStartedResponse(
            message=(
                f"Agent '{library_agent.name}' execution started successfully. "
@@ -618,7 +520,7 @@ class RunAgentTool(BaseTool):
        library_agent = await get_or_create_library_agent(graph, user_id)
 
        # Get user timezone
-        user = await user_db().get_user_by_id(user_id)
+        user = await get_user_by_id(user_id)
        user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
 
        # Create schedule
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
     # Should return error about missing schedule_name
     assert result_data.get("type") == "error"
     assert "schedule_name" in result_data["message"].lower()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
-    """Test that run_agent returns input_validation_error for unknown input fields."""
-    user = setup_test_data["user"]
-    store_submission = setup_test_data["store_submission"]
-
-    tool = RunAgentTool()
-    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
-    session = make_session(user_id=user.id)
-
-    # Execute with unknown input field names
-    response = await tool.execute(
-        user_id=user.id,
-        session_id=str(uuid.uuid4()),
-        tool_call_id=str(uuid.uuid4()),
-        username_agent_slug=agent_marketplace_id,
-        inputs={
-            "unknown_field": "some value",
-            "another_unknown": "another value",
-        },
-        session=session,
-    )
-
-    assert response is not None
-    assert hasattr(response, "output")
-    assert isinstance(response.output, str)
-    result_data = orjson.loads(response.output)
-
-    # Should return input_validation_error type with unrecognized fields
-    assert result_data.get("type") == "input_validation_error"
-    assert "unrecognized_fields" in result_data
-    assert set(result_data["unrecognized_fields"]) == {
-        "another_unknown",
-        "unknown_field",
-    }
-    assert "inputs" in result_data  # Contains the valid schema
-    assert "Agent was not executed" in result_data["message"]
@@ -0,0 +1,346 @@
+"""Tool for executing blocks directly."""
+
+import logging
+import uuid
+from collections import defaultdict
+from typing import Any
+
+from backend.api.features.chat.model import ChatSession
+from backend.data.block import get_block
+from backend.data.execution import ExecutionContext
+from backend.data.model import CredentialsMetaInput
+from backend.data.workspace import get_or_create_workspace
+from backend.integrations.creds_manager import IntegrationCredentialsManager
+from backend.util.exceptions import BlockError
+
+from .base import BaseTool
+from .models import (
+    BlockOutputResponse,
+    ErrorResponse,
+    SetupInfo,
+    SetupRequirementsResponse,
+    ToolResponseBase,
+    UserReadiness,
+)
+from .utils import build_missing_credentials_from_field_info
+
+logger = logging.getLogger(__name__)
+
+
+class RunBlockTool(BaseTool):
+    """Tool for executing a block and returning its outputs."""
+
+    @property
+    def name(self) -> str:
+        return "run_block"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Execute a specific block with the provided input data. "
+            "IMPORTANT: You MUST call find_block first to get the block's 'id' - "
+            "do NOT guess or make up block IDs. "
+            "Use the 'id' from find_block results and provide input_data "
+            "matching the block's required_inputs."
+        )
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "block_id": {
+                    "type": "string",
+                    "description": (
+                        "The block's 'id' field from find_block results. "
+                        "NEVER guess this - always get it from find_block first."
+                    ),
+                },
+                "input_data": {
+                    "type": "object",
+                    "description": (
+                        "Input values for the block. Use the 'required_inputs' field "
+                        "from find_block to see what fields are needed."
+                    ),
+                },
+            },
+            "required": ["block_id", "input_data"],
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    async def _check_block_credentials(
+        self,
+        user_id: str,
+        block: Any,
+    ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+        """
+        Check if user has required credentials for a block.
+
+        Returns:
+            tuple[matched_credentials, missing_credentials]
+        """
+        matched_credentials: dict[str, CredentialsMetaInput] = {}
+        missing_credentials: list[CredentialsMetaInput] = []
+
+        # Get credential field info from block's input schema
+        credentials_fields_info = block.input_schema.get_credentials_fields_info()
+
+        if not credentials_fields_info:
+            return matched_credentials, missing_credentials
+
+        # Get user's available credentials
+        creds_manager = IntegrationCredentialsManager()
+        available_creds = await creds_manager.store.get_all_creds(user_id)
+
+        for field_name, field_info in credentials_fields_info.items():
+            # field_info.provider is a frozenset of acceptable providers
+            # field_info.supported_types is a frozenset of acceptable types
+            matching_cred = next(
+                (
+                    cred
+                    for cred in available_creds
+                    if cred.provider in field_info.provider
+                    and cred.type in field_info.supported_types
+                ),
+                None,
+            )
+
+            if matching_cred:
+                matched_credentials[field_name] = CredentialsMetaInput(
+                    id=matching_cred.id,
+                    provider=matching_cred.provider,  # type: ignore
+                    type=matching_cred.type,
+                    title=matching_cred.title,
+                )
+            else:
+                # Create a placeholder for the missing credential
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
+                missing_credentials.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=field_name.replace("_", " ").title(),
+                    )
+                )
+
+        return matched_credentials, missing_credentials
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        """Execute a block with the given input data.
+
+        Args:
+            user_id: User ID (required)
+            session: Chat session
+            block_id: Block UUID to execute
+            input_data: Input values for the block
+
+        Returns:
+            BlockOutputResponse: Block execution outputs
+            SetupRequirementsResponse: Missing credentials
+            ErrorResponse: Error message
+        """
+        block_id = kwargs.get("block_id", "").strip()
+        input_data = kwargs.get("input_data", {})
+        session_id = session.session_id
+
+        if not block_id:
+            return ErrorResponse(
+                message="Please provide a block_id",
+                session_id=session_id,
+            )
+
+        if not isinstance(input_data, dict):
+            return ErrorResponse(
+                message="input_data must be an object",
+                session_id=session_id,
+            )
+
+        if not user_id:
+            return ErrorResponse(
+                message="Authentication required",
+                session_id=session_id,
+            )
+
+        # Get the block
+        block = get_block(block_id)
+        if not block:
+            return ErrorResponse(
+                message=f"Block '{block_id}' not found",
+                session_id=session_id,
+            )
+        if block.disabled:
+            return ErrorResponse(
+                message=f"Block '{block_id}' is disabled",
+                session_id=session_id,
+            )
+
+        logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
+
+        # Check credentials
+        creds_manager = IntegrationCredentialsManager()
+        matched_credentials, missing_credentials = await self._check_block_credentials(
+            user_id, block
+        )
+
+        if missing_credentials:
+            # Return setup requirements response with missing credentials
+            credentials_fields_info = block.input_schema.get_credentials_fields_info()
+            missing_creds_dict = build_missing_credentials_from_field_info(
+                credentials_fields_info, set(matched_credentials.keys())
+            )
+            missing_creds_list = list(missing_creds_dict.values())
+
+            return SetupRequirementsResponse(
+                message=(
+                    f"Block '{block.name}' requires credentials that are not configured. "
+                    "Please set up the required credentials before running this block."
+                ),
+                session_id=session_id,
+                setup_info=SetupInfo(
+                    agent_id=block_id,
+                    agent_name=block.name,
+                    user_readiness=UserReadiness(
+                        has_all_credentials=False,
+                        missing_credentials=missing_creds_dict,
+                        ready_to_run=False,
+                    ),
+                    requirements={
+                        "credentials": missing_creds_list,
+                        "inputs": self._get_inputs_list(block),
+                        "execution_modes": ["immediate"],
+                    },
+                ),
+                graph_id=None,
+                graph_version=None,
+            )
+
+        try:
+            # Get or create user's workspace for CoPilot file operations
+            workspace = await get_or_create_workspace(user_id)
+
+            # Generate synthetic IDs for CoPilot context
+            # Each chat session is treated as its own agent with one continuous run
+            # This means:
+            # - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
+            # - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
+            # - node_exec_id = unique per block execution
+            synthetic_graph_id = f"copilot-session-{session.session_id}"
+            synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
+            synthetic_node_id = f"copilot-node-{block_id}"
+            synthetic_node_exec_id = (
+                f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
+            )
+
+            # Create unified execution context with all required fields
+            execution_context = ExecutionContext(
+                # Execution identity
+                user_id=user_id,
+                graph_id=synthetic_graph_id,
+                graph_exec_id=synthetic_graph_exec_id,
+                graph_version=1,  # Versions are 1-indexed
+                node_id=synthetic_node_id,
+                node_exec_id=synthetic_node_exec_id,
+                # Workspace with session scoping
+                workspace_id=workspace.id,
+                session_id=session.session_id,
+            )
+
+            # Prepare kwargs for block execution
+            # Keep individual kwargs for backwards compatibility with existing blocks
+            exec_kwargs: dict[str, Any] = {
+                "user_id": user_id,
+                "execution_context": execution_context,
+                # Legacy: individual kwargs for blocks not yet using execution_context
+                "workspace_id": workspace.id,
+                "graph_exec_id": synthetic_graph_exec_id,
+                "node_exec_id": synthetic_node_exec_id,
+                "node_id": synthetic_node_id,
+                "graph_version": 1,  # Versions are 1-indexed
+                "graph_id": synthetic_graph_id,
+            }
+
+            for field_name, cred_meta in matched_credentials.items():
+                # Inject metadata into input_data (for validation)
+                if field_name not in input_data:
+                    input_data[field_name] = cred_meta.model_dump()
+
+                # Fetch actual credentials and pass as kwargs (for execution)
+                actual_credentials = await creds_manager.get(
+                    user_id, cred_meta.id, lock=False
+                )
+                if actual_credentials:
+                    exec_kwargs[field_name] = actual_credentials
+                else:
+                    return ErrorResponse(
+                        message=f"Failed to retrieve credentials for {field_name}",
+                        session_id=session_id,
+                    )
+
+            # Execute the block and collect outputs
+            outputs: dict[str, list[Any]] = defaultdict(list)
+            async for output_name, output_data in block.execute(
+                input_data,
+                **exec_kwargs,
+            ):
+                outputs[output_name].append(output_data)
+
+            return BlockOutputResponse(
+                message=f"Block '{block.name}' executed successfully",
+                block_id=block_id,
+                block_name=block.name,
+                outputs=dict(outputs),
+                success=True,
+                session_id=session_id,
+            )
+
+        except BlockError as e:
+            logger.warning(f"Block execution failed: {e}")
+            return ErrorResponse(
+                message=f"Block execution failed: {e}",
+                error=str(e),
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Unexpected error executing block: {e}", exc_info=True)
+            return ErrorResponse(
+                message=f"Failed to execute block: {str(e)}",
+                error=str(e),
+                session_id=session_id,
+            )
+
+    def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
+        """Extract non-credential inputs from block schema."""
+        inputs_list = []
+        schema = block.input_schema.jsonschema()
+        properties = schema.get("properties", {})
+        required_fields = set(schema.get("required", []))
+
+        # Get credential field names to exclude
+        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
+
+        for field_name, field_schema in properties.items():
+            # Skip credential fields
+            if field_name in credentials_fields:
+                continue
+
+            inputs_list.append(
+                {
+                    "name": field_name,
+                    "title": field_schema.get("title", field_name),
+                    "type": field_schema.get("type", "string"),
+                    "description": field_schema.get("description", ""),
+                    "required": field_name in required_fields,
+                }
+            )
+
+        return inputs_list
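The credential-matching logic in the diff above picks the first stored credential whose provider and type both fall in the block's acceptable sets. A minimal standalone sketch of that selection rule (the `Credential` class and sample data here are illustrative, not the repo's actual models):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Credential:
    """Simplified stand-in for a stored integration credential."""
    id: str
    provider: str
    type: str


def find_matching_credential(available, providers: frozenset, types: frozenset):
    """Return the first credential whose provider and type are both acceptable."""
    return next(
        (c for c in available if c.provider in providers and c.type in types),
        None,  # no match -> caller records the field as a missing credential
    )


creds = [Credential("1", "github", "oauth2"), Credential("2", "openai", "api_key")]
match = find_matching_credential(creds, frozenset({"openai"}), frozenset({"api_key"}))
# match is the "openai" api_key credential; an unmatched requirement yields None
```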
@@ -5,17 +5,16 @@ from typing import Any
 
 from prisma.enums import ContentType
 
-from backend.copilot.model import ChatSession
-from backend.data.db_accessors import search
-from .base import BaseTool
-from .models import (
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools.base import BaseTool
+from backend.api.features.chat.tools.models import (
     DocSearchResult,
     DocSearchResultsResponse,
     ErrorResponse,
     NoResultsResponse,
     ToolResponseBase,
 )
+from backend.api.features.store.hybrid_search import unified_hybrid_search
 
 logger = logging.getLogger(__name__)
 
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
 
        try:
            # Search using hybrid search for DOCUMENTATION content type only
-            results, total = await search().unified_hybrid_search(
+            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
@@ -3,18 +3,13 @@
 import logging
 from typing import Any
 
+from backend.api.features.library import db as library_db
 from backend.api.features.library import model as library_model
-from backend.data.db_accessors import library_db, store_db
+from backend.api.features.store import db as store_db
+from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import (
-    Credentials,
-    CredentialsFieldInfo,
-    CredentialsMetaInput,
-    HostScopedCredentials,
-    OAuth2Credentials,
-)
+from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
-from backend.integrations.providers import ProviderName
 from backend.util.exceptions import NotFoundError
 
 logger = logging.getLogger(__name__)
@@ -38,15 +33,20 @@ async def fetch_graph_from_store_slug(
     Raises:
         DatabaseError: If there's a database error during lookup.
     """
-    sdb = store_db()
     try:
-        store_agent = await sdb.get_store_agent_details(username, agent_name)
+        store_agent = await store_db.get_store_agent_details(username, agent_name)
     except NotFoundError:
         return None, None
 
     # Get the graph from store listing version
-    graph = await sdb.get_available_graph(
-        store_agent.store_listing_version_id, hide_nodes=False
+    graph_meta = await store_db.get_available_graph(
+        store_agent.store_listing_version_id
+    )
+    graph = await graph_db.get_graph(
+        graph_id=graph_meta.id,
+        version=graph_meta.version,
+        user_id=None,  # Public access
+        include_subgraphs=True,
     )
     return graph, store_agent
 
@@ -123,7 +123,7 @@ def build_missing_credentials_from_graph(
 
     return {
         field_key: _serialize_missing_credential(field_key, field_info)
-        for field_key, (field_info, _, _) in aggregated_fields.items()
+        for field_key, (field_info, _node_fields) in aggregated_fields.items()
         if field_key not in matched_keys
     }
 
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
     Returns:
         LibraryAgent instance
     """
-    existing = await library_db().get_library_agent_by_graph_id(
+    existing = await library_db.get_library_agent_by_graph_id(
         graph_id=graph.id, user_id=user_id
     )
     if existing:
         return existing
 
-    library_agents = await library_db().create_library_agent(
+    library_agents = await library_db.create_library_agent(
         graph=graph,
         user_id=user_id,
         create_library_agents_for_sub_graphs=False,
@@ -225,99 +225,6 @@ async def get_or_create_library_agent(
     return library_agents[0]


-async def match_credentials_to_requirements(
-    user_id: str,
-    requirements: dict[str, CredentialsFieldInfo],
-) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
-    """
-    Match user's credentials against a dictionary of credential requirements.
-
-    This is the core matching logic shared by both graph and block credential matching.
-    """
-    matched: dict[str, CredentialsMetaInput] = {}
-    missing: list[CredentialsMetaInput] = []
-
-    if not requirements:
-        return matched, missing
-
-    available_creds = await get_user_credentials(user_id)
-
-    for field_name, field_info in requirements.items():
-        matching_cred = find_matching_credential(available_creds, field_info)
-
-        if matching_cred:
-            try:
-                matched[field_name] = create_credential_meta_from_match(matching_cred)
-            except Exception as e:
-                logger.error(
-                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
-                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
-                    f"credential_id={matching_cred.id}",
-                    exc_info=True,
-                )
-                provider = next(iter(field_info.provider), "unknown")
-                cred_type = next(iter(field_info.supported_types), "api_key")
-                missing.append(
-                    CredentialsMetaInput(
-                        id=field_name,
-                        provider=provider,  # type: ignore
-                        type=cred_type,  # type: ignore
-                        title=f"{field_name} (validation failed: {e})",
-                    )
-                )
-        else:
-            provider = next(iter(field_info.provider), "unknown")
-            cred_type = next(iter(field_info.supported_types), "api_key")
-            missing.append(
-                CredentialsMetaInput(
-                    id=field_name,
-                    provider=provider,  # type: ignore
-                    type=cred_type,  # type: ignore
-                    title=field_name.replace("_", " ").title(),
-                )
-            )
-
-    return matched, missing
-
-
-async def get_user_credentials(user_id: str) -> list[Credentials]:
-    """Get all available credentials for a user."""
-    creds_manager = IntegrationCredentialsManager()
-    return await creds_manager.store.get_all_creds(user_id)
-
-
-def find_matching_credential(
-    available_creds: list[Credentials],
-    field_info: CredentialsFieldInfo,
-) -> Credentials | None:
-    """Find a credential that matches the required provider, type, scopes, and host."""
-    for cred in available_creds:
-        if cred.provider not in field_info.provider:
-            continue
-        if cred.type not in field_info.supported_types:
-            continue
-        if cred.type == "oauth2" and not _credential_has_required_scopes(
-            cred, field_info
-        ):
-            continue
-        if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
-            continue
-        return cred
-    return None
-
-
-def create_credential_meta_from_match(
-    matching_cred: Credentials,
-) -> CredentialsMetaInput:
-    """Create a CredentialsMetaInput from a matched credential."""
-    return CredentialsMetaInput(
-        id=matching_cred.id,
-        provider=matching_cred.provider,  # type: ignore
-        type=matching_cred.type,
-        title=matching_cred.title,
-    )
-
-
 async def match_user_credentials_to_graph(
     user_id: str,
     graph: GraphModel,
@@ -357,28 +264,16 @@ async def match_user_credentials_to_graph(
     # provider is in the set of acceptable providers.
     for credential_field_name, (
         credential_requirements,
-        _,
-        _,
+        _node_fields,
     ) in aggregated_creds.items():
-        # Find first matching credential by provider, type, scopes, and host/URL
+        # Find first matching credential by provider, type, and scopes
         matching_cred = next(
             (
                 cred
                 for cred in available_creds
                 if cred.provider in credential_requirements.provider
                 and cred.type in credential_requirements.supported_types
-                and (
-                    cred.type != "oauth2"
-                    or _credential_has_required_scopes(cred, credential_requirements)
-                )
-                and (
-                    cred.type != "host_scoped"
-                    or _credential_is_for_host(cred, credential_requirements)
-                )
-                and (
-                    cred.provider != ProviderName.MCP
-                    or _credential_is_for_mcp_server(cred, credential_requirements)
-                )
+                and _credential_has_required_scopes(cred, credential_requirements)
             ),
             None,
         )
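The hunk above relies on `next()` over a generator with a `None` default to pick the first credential satisfying every predicate without raising `StopIteration`. A standalone sketch of that pattern (the `Cred` class and all names here are illustrative stand-ins, not the repository's actual types):

```python
# First-match selection: next() with a generator expression and a None
# default returns the first element passing all filters, or None.
from dataclasses import dataclass


@dataclass
class Cred:
    provider: str
    type: str
    scopes: frozenset[str]


def first_match(creds, providers, types, required_scopes):
    return next(
        (
            c
            for c in creds
            if c.provider in providers
            and c.type in types
            # oauth2 credentials must carry every required scope
            and (c.type != "oauth2" or c.scopes >= required_scopes)
        ),
        None,
    )


creds = [
    Cred("github", "oauth2", frozenset({"repo"})),
    Cred("github", "oauth2", frozenset({"repo", "workflow"})),
]
# First cred lacks the "workflow" scope, so the second one is selected.
match = first_match(creds, {"github"}, {"oauth2"}, {"repo", "workflow"})
```

The `None` default is what lets the caller branch into the "missing credential" path instead of handling an exception.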
@@ -423,46 +318,25 @@ async def match_user_credentials_to_graph(


 def _credential_has_required_scopes(
-    credential: OAuth2Credentials,
-    requirements: CredentialsFieldInfo,
-) -> bool:
-    """Check if an OAuth2 credential has all the scopes required by the input."""
-    # If no scopes are required, any credential matches
-    if not requirements.required_scopes:
-        return True
-    return set(credential.scopes).issuperset(requirements.required_scopes)
-
-
-def _credential_is_for_host(
-    credential: HostScopedCredentials,
-    requirements: CredentialsFieldInfo,
-) -> bool:
-    """Check if a host-scoped credential matches the host required by the input."""
-    # We need to know the host to match host-scoped credentials to.
-    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
-    # to discriminator_values. No discriminator_values -> no host to match against.
-    if not requirements.discriminator_values:
-        return True
-
-    # Check that credential host matches required host.
-    # Host-scoped credential inputs are grouped by host, so any item from the set works.
-    return credential.matches_url(list(requirements.discriminator_values)[0])
-
-
-def _credential_is_for_mcp_server(
     credential: Credentials,
     requirements: CredentialsFieldInfo,
 ) -> bool:
-    """Check if an MCP OAuth credential matches the required server URL."""
-    if not requirements.discriminator_values:
+    """
+    Check if a credential has all the scopes required by the block.
+
+    For OAuth2 credentials, verifies that the credential's scopes are a superset
+    of the required scopes. For other credential types, returns True (no scope check).
+    """
+    # Only OAuth2 credentials have scopes to check
+    if credential.type != "oauth2":
         return True

-    server_url = (
-        credential.metadata.get("mcp_server_url")
-        if isinstance(credential, OAuth2Credentials)
-        else None
-    )
-    return server_url in requirements.discriminator_values if server_url else False
+    # If no scopes are required, any credential matches
+    if not requirements.required_scopes:
+        return True
+    # Check that credential scopes are a superset of required scopes
+    return set(credential.scopes).issuperset(requirements.required_scopes)


 async def check_user_has_required_credentials(
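The scope check this hunk consolidates reduces to three cases: non-OAuth2 credentials pass unconditionally, an empty requirement passes, and otherwise the credential's scopes must be a superset of the required ones. A minimal sketch with a simplified signature (the real function takes credential/requirements objects, not plain sets):

```python
# Superset-based scope check, mirroring the three branches of the new
# _credential_has_required_scopes body (simplified to plain values).
def has_required_scopes(cred_type: str, cred_scopes: set[str], required: set[str]) -> bool:
    if cred_type != "oauth2":
        return True  # only OAuth2 credentials carry scopes
    if not required:
        return True  # nothing required -> any credential matches
    return cred_scopes >= required  # set superset comparison


print(has_required_scopes("api_key", set(), {"repo"}))            # True
print(has_required_scopes("oauth2", {"repo"}, {"repo", "gist"}))  # False
```

Note that `>=` on sets is equivalent to `issuperset`, so a credential with extra scopes still matches.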
@@ -0,0 +1,620 @@
+"""CoPilot tools for workspace file operations."""
+
+import base64
+import logging
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
+from backend.api.features.chat.model import ChatSession
+from backend.data.workspace import get_or_create_workspace
+from backend.util.settings import Config
+from backend.util.virus_scanner import scan_content_safe
+from backend.util.workspace import WorkspaceManager
+
+from .base import BaseTool
+from .models import ErrorResponse, ResponseType, ToolResponseBase
+
+logger = logging.getLogger(__name__)
+
+
+class WorkspaceFileInfoData(BaseModel):
+    """Data model for workspace file information (not a response itself)."""
+
+    file_id: str
+    name: str
+    path: str
+    mime_type: str
+    size_bytes: int
+
+
+class WorkspaceFileListResponse(ToolResponseBase):
+    """Response containing list of workspace files."""
+
+    type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
+    files: list[WorkspaceFileInfoData]
+    total_count: int
+
+
+class WorkspaceFileContentResponse(ToolResponseBase):
+    """Response containing workspace file content (legacy, for small text files)."""
+
+    type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
+    file_id: str
+    name: str
+    path: str
+    mime_type: str
+    content_base64: str
+
+
+class WorkspaceFileMetadataResponse(ToolResponseBase):
+    """Response containing workspace file metadata and download URL (prevents context bloat)."""
+
+    type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
+    file_id: str
+    name: str
+    path: str
+    mime_type: str
+    size_bytes: int
+    download_url: str
+    preview: str | None = None  # First 500 chars for text files
+
+
+class WorkspaceWriteResponse(ToolResponseBase):
+    """Response after writing a file to workspace."""
+
+    type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
+    file_id: str
+    name: str
+    path: str
+    size_bytes: int
+
+
+class WorkspaceDeleteResponse(ToolResponseBase):
+    """Response after deleting a file from workspace."""
+
+    type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
+    file_id: str
+    success: bool
+
+
+class ListWorkspaceFilesTool(BaseTool):
+    """Tool for listing files in user's workspace."""
+
+    @property
+    def name(self) -> str:
+        return "list_workspace_files"
+
+    @property
+    def description(self) -> str:
+        return (
+            "List files in the user's workspace. "
+            "Returns file names, paths, sizes, and metadata. "
+            "Optionally filter by path prefix."
+        )
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "path_prefix": {
+                    "type": "string",
+                    "description": (
+                        "Optional path prefix to filter files "
+                        "(e.g., '/documents/' to list only files in documents folder). "
+                        "By default, only files from the current session are listed."
+                    ),
+                },
+                "limit": {
+                    "type": "integer",
+                    "description": "Maximum number of files to return (default 50, max 100)",
+                    "minimum": 1,
+                    "maximum": 100,
+                },
+                "include_all_sessions": {
+                    "type": "boolean",
+                    "description": (
+                        "If true, list files from all sessions. "
+                        "Default is false (only current session's files)."
+                    ),
+                },
+            },
+            "required": [],
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        session_id = session.session_id
+
+        if not user_id:
+            return ErrorResponse(
+                message="Authentication required",
+                session_id=session_id,
+            )
+
+        path_prefix: Optional[str] = kwargs.get("path_prefix")
+        limit = min(kwargs.get("limit", 50), 100)
+        include_all_sessions: bool = kwargs.get("include_all_sessions", False)
+
+        try:
+            workspace = await get_or_create_workspace(user_id)
+            # Pass session_id for session-scoped file access
+            manager = WorkspaceManager(user_id, workspace.id, session_id)
+
+            files = await manager.list_files(
+                path=path_prefix,
+                limit=limit,
+                include_all_sessions=include_all_sessions,
+            )
+            total = await manager.get_file_count(
+                path=path_prefix,
+                include_all_sessions=include_all_sessions,
+            )
+
+            file_infos = [
+                WorkspaceFileInfoData(
+                    file_id=f.id,
+                    name=f.name,
+                    path=f.path,
+                    mime_type=f.mimeType,
+                    size_bytes=f.sizeBytes,
+                )
+                for f in files
+            ]
+
+            scope_msg = "all sessions" if include_all_sessions else "current session"
+            return WorkspaceFileListResponse(
+                files=file_infos,
+                total_count=total,
+                message=f"Found {len(files)} files in workspace ({scope_msg})",
+                session_id=session_id,
+            )
+
+        except Exception as e:
+            logger.error(f"Error listing workspace files: {e}", exc_info=True)
+            return ErrorResponse(
+                message=f"Failed to list workspace files: {str(e)}",
+                error=str(e),
+                session_id=session_id,
+            )
+
+
+class ReadWorkspaceFileTool(BaseTool):
+    """Tool for reading file content from workspace."""
+
+    # Size threshold for returning full content vs metadata+URL
+    # Files larger than this return metadata with download URL to prevent context bloat
+    MAX_INLINE_SIZE_BYTES = 32 * 1024  # 32KB
+    # Preview size for text files
+    PREVIEW_SIZE = 500
+
+    @property
+    def name(self) -> str:
+        return "read_workspace_file"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Read a file from the user's workspace. "
+            "Specify either file_id or path to identify the file. "
+            "For small text files, returns content directly. "
+            "For large or binary files, returns metadata and a download URL. "
+            "Paths are scoped to the current session by default. "
+            "Use /sessions/<session_id>/... for cross-session access."
+        )
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "file_id": {
+                    "type": "string",
+                    "description": "The file's unique ID (from list_workspace_files)",
+                },
+                "path": {
+                    "type": "string",
+                    "description": (
+                        "The virtual file path (e.g., '/documents/report.pdf'). "
+                        "Scoped to current session by default."
+                    ),
+                },
+                "force_download_url": {
+                    "type": "boolean",
+                    "description": (
+                        "If true, always return metadata+URL instead of inline content. "
+                        "Default is false (auto-selects based on file size/type)."
+                    ),
+                },
+            },
+            "required": [],  # At least one must be provided
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    def _is_text_mime_type(self, mime_type: str) -> bool:
+        """Check if the MIME type is a text-based type."""
+        text_types = [
+            "text/",
+            "application/json",
+            "application/xml",
+            "application/javascript",
+            "application/x-python",
+            "application/x-sh",
+        ]
+        return any(mime_type.startswith(t) for t in text_types)
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        session_id = session.session_id
+
+        if not user_id:
+            return ErrorResponse(
+                message="Authentication required",
+                session_id=session_id,
+            )
+
+        file_id: Optional[str] = kwargs.get("file_id")
+        path: Optional[str] = kwargs.get("path")
+        force_download_url: bool = kwargs.get("force_download_url", False)
+
+        if not file_id and not path:
+            return ErrorResponse(
+                message="Please provide either file_id or path",
+                session_id=session_id,
+            )
+
+        try:
+            workspace = await get_or_create_workspace(user_id)
+            # Pass session_id for session-scoped file access
+            manager = WorkspaceManager(user_id, workspace.id, session_id)
+
+            # Get file info
+            if file_id:
+                file_info = await manager.get_file_info(file_id)
+                if file_info is None:
+                    return ErrorResponse(
+                        message=f"File not found: {file_id}",
+                        session_id=session_id,
+                    )
+                target_file_id = file_id
+            else:
+                # path is guaranteed to be non-None here due to the check above
+                assert path is not None
+                file_info = await manager.get_file_info_by_path(path)
+                if file_info is None:
+                    return ErrorResponse(
+                        message=f"File not found at path: {path}",
+                        session_id=session_id,
+                    )
+                target_file_id = file_info.id
+
+            # Decide whether to return inline content or metadata+URL
+            is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
+            is_text_file = self._is_text_mime_type(file_info.mimeType)
+
+            # Return inline content for small text files (unless force_download_url)
+            if is_small_file and is_text_file and not force_download_url:
+                content = await manager.read_file_by_id(target_file_id)
+                content_b64 = base64.b64encode(content).decode("utf-8")
+
+                return WorkspaceFileContentResponse(
+                    file_id=file_info.id,
+                    name=file_info.name,
+                    path=file_info.path,
+                    mime_type=file_info.mimeType,
+                    content_base64=content_b64,
+                    message=f"Successfully read file: {file_info.name}",
+                    session_id=session_id,
+                )
+
+            # Return metadata + workspace:// reference for large or binary files
+            # This prevents context bloat (100KB file = ~133KB as base64)
+            # Use workspace:// format so frontend urlTransform can add proxy prefix
+            download_url = f"workspace://{target_file_id}"
+
+            # Generate preview for text files
+            preview: str | None = None
+            if is_text_file:
+                try:
+                    content = await manager.read_file_by_id(target_file_id)
+                    preview_text = content[: self.PREVIEW_SIZE].decode(
+                        "utf-8", errors="replace"
+                    )
+                    if len(content) > self.PREVIEW_SIZE:
+                        preview_text += "..."
+                    preview = preview_text
+                except Exception:
+                    pass  # Preview is optional
+
+            return WorkspaceFileMetadataResponse(
+                file_id=file_info.id,
+                name=file_info.name,
+                path=file_info.path,
+                mime_type=file_info.mimeType,
+                size_bytes=file_info.sizeBytes,
+                download_url=download_url,
+                preview=preview,
+                message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
+                session_id=session_id,
+            )
+
+        except FileNotFoundError as e:
+            return ErrorResponse(
+                message=str(e),
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Error reading workspace file: {e}", exc_info=True)
+            return ErrorResponse(
+                message=f"Failed to read workspace file: {str(e)}",
+                error=str(e),
+                session_id=session_id,
+            )
+
+
+class WriteWorkspaceFileTool(BaseTool):
+    """Tool for writing files to workspace."""
+
+    @property
+    def name(self) -> str:
+        return "write_workspace_file"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Write or create a file in the user's workspace. "
+            "Provide the content as a base64-encoded string. "
+            f"Maximum file size is {Config().max_file_size_mb}MB. "
+            "Files are saved to the current session's folder by default. "
+            "Use /sessions/<session_id>/... for cross-session access."
+        )
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "filename": {
+                    "type": "string",
+                    "description": "Name for the file (e.g., 'report.pdf')",
+                },
+                "content_base64": {
+                    "type": "string",
+                    "description": "Base64-encoded file content",
+                },
+                "path": {
+                    "type": "string",
+                    "description": (
+                        "Optional virtual path where to save the file "
+                        "(e.g., '/documents/report.pdf'). "
+                        "Defaults to '/(unknown)'. Scoped to current session."
+                    ),
+                },
+                "mime_type": {
+                    "type": "string",
+                    "description": (
+                        "Optional MIME type of the file. "
+                        "Auto-detected from filename if not provided."
+                    ),
+                },
+                "overwrite": {
+                    "type": "boolean",
+                    "description": "Whether to overwrite if file exists at path (default: false)",
+                },
+            },
+            "required": ["filename", "content_base64"],
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        session_id = session.session_id
+
+        if not user_id:
+            return ErrorResponse(
+                message="Authentication required",
+                session_id=session_id,
+            )
+
+        filename: str = kwargs.get("filename", "")
+        content_b64: str = kwargs.get("content_base64", "")
+        path: Optional[str] = kwargs.get("path")
+        mime_type: Optional[str] = kwargs.get("mime_type")
+        overwrite: bool = kwargs.get("overwrite", False)
+
+        if not filename:
+            return ErrorResponse(
+                message="Please provide a filename",
+                session_id=session_id,
+            )
+
+        if not content_b64:
+            return ErrorResponse(
+                message="Please provide content_base64",
+                session_id=session_id,
+            )
+
+        # Decode content
+        try:
+            content = base64.b64decode(content_b64)
+        except Exception:
+            return ErrorResponse(
+                message="Invalid base64-encoded content",
+                session_id=session_id,
+            )
+
+        # Check size
+        max_file_size = Config().max_file_size_mb * 1024 * 1024
+        if len(content) > max_file_size:
+            return ErrorResponse(
+                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
+                session_id=session_id,
+            )
+
+        try:
+            # Virus scan
+            await scan_content_safe(content, filename=filename)
+
+            workspace = await get_or_create_workspace(user_id)
+            # Pass session_id for session-scoped file access
+            manager = WorkspaceManager(user_id, workspace.id, session_id)
+
+            file_record = await manager.write_file(
+                content=content,
+                filename=filename,
+                path=path,
+                mime_type=mime_type,
+                overwrite=overwrite,
+            )
+
+            return WorkspaceWriteResponse(
+                file_id=file_record.id,
+                name=file_record.name,
+                path=file_record.path,
+                size_bytes=file_record.sizeBytes,
+                message=f"Successfully wrote file: {file_record.name}",
+                session_id=session_id,
+            )
+
+        except ValueError as e:
+            return ErrorResponse(
+                message=str(e),
+                session_id=session_id,
+            )
+        except Exception as e:
+            logger.error(f"Error writing workspace file: {e}", exc_info=True)
+            return ErrorResponse(
+                message=f"Failed to write workspace file: {str(e)}",
+                error=str(e),
+                session_id=session_id,
+            )
+
+
+class DeleteWorkspaceFileTool(BaseTool):
+    """Tool for deleting files from workspace."""
+
+    @property
+    def name(self) -> str:
+        return "delete_workspace_file"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Delete a file from the user's workspace. "
+            "Specify either file_id or path to identify the file. "
+            "Paths are scoped to the current session by default. "
+            "Use /sessions/<session_id>/... for cross-session access."
+        )
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "file_id": {
+                    "type": "string",
+                    "description": "The file's unique ID (from list_workspace_files)",
+                },
+                "path": {
+                    "type": "string",
+                    "description": (
+                        "The virtual file path (e.g., '/documents/report.pdf'). "
+                        "Scoped to current session by default."
+                    ),
+                },
+            },
+            "required": [],  # At least one must be provided
+        }
+
+    @property
+    def requires_auth(self) -> bool:
+        return True
+
+    async def _execute(
+        self,
+        user_id: str | None,
+        session: ChatSession,
+        **kwargs,
+    ) -> ToolResponseBase:
+        session_id = session.session_id
+
+        if not user_id:
+            return ErrorResponse(
+                message="Authentication required",
+                session_id=session_id,
+            )
+
+        file_id: Optional[str] = kwargs.get("file_id")
+        path: Optional[str] = kwargs.get("path")
+
+        if not file_id and not path:
+            return ErrorResponse(
+                message="Please provide either file_id or path",
+                session_id=session_id,
+            )
+
+        try:
+            workspace = await get_or_create_workspace(user_id)
+            # Pass session_id for session-scoped file access
+            manager = WorkspaceManager(user_id, workspace.id, session_id)
+
+            # Determine the file_id to delete
+            target_file_id: str
+            if file_id:
+                target_file_id = file_id
+            else:
+                # path is guaranteed to be non-None here due to the check above
+                assert path is not None
+                file_info = await manager.get_file_info_by_path(path)
+                if file_info is None:
+                    return ErrorResponse(
+                        message=f"File not found at path: {path}",
+                        session_id=session_id,
+                    )
+                target_file_id = file_info.id
+
+            success = await manager.delete_file(target_file_id)
+
+            if not success:
+                return ErrorResponse(
+                    message=f"File not found: {target_file_id}",
+                    session_id=session_id,
+                )
+
+            return WorkspaceDeleteResponse(
+                file_id=target_file_id,
+                success=True,
+                message="File deleted successfully",
+                session_id=session_id,
+            )
+
+        except Exception as e:
+            logger.error(f"Error deleting workspace file: {e}", exc_info=True)
+            return ErrorResponse(
+                message=f"Failed to delete workspace file: {str(e)}",
+                error=str(e),
+                session_id=session_id,
+            )
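The write tool above decodes a base64 payload, enforces a size limit, and the read tool truncates text previews at 500 characters. A sketch of preparing such a payload and applying the same preview logic (the constants are mirrored from the diff; the payload dict is illustrative, not the tool's exact wire format):

```python
# Prepare a write_workspace_file-style payload and apply the inline/preview
# decisions ReadWorkspaceFileTool makes (32KB inline threshold, 500-char preview).
import base64

MAX_INLINE_SIZE_BYTES = 32 * 1024
PREVIEW_SIZE = 500

content = ("line\n" * 200).encode("utf-8")  # 1000 bytes of text
payload = {
    "filename": "notes.txt",
    "content_base64": base64.b64encode(content).decode("utf-8"),
}

# Round-trip decode, as the tool does before its size check
decoded = base64.b64decode(payload["content_base64"])
assert decoded == content

inline = len(decoded) <= MAX_INLINE_SIZE_BYTES  # small file -> inline content
preview = decoded[:PREVIEW_SIZE].decode("utf-8", errors="replace")
if len(decoded) > PREVIEW_SIZE:
    preview += "..."
```

The base64 round-trip also illustrates why large files are better served by the `workspace://` metadata path: encoding inflates content by roughly a third.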
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
 
     # Mock get_node_executions to return node_id mapping
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
     )
     mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
     mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
 
     # Mock get_node_executions to return node_id mapping
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
     )
     mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
     mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
 
     # Mock get_node_executions to return batch node data
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
    )
     # Create mock node executions for each review
     mock_node_execs = []
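The hunks above retarget `mocker.patch` from the route module to `backend.data.execution`, which lines up with the route now importing `get_node_executions` lazily inside the function. As a self-contained sketch of the underlying rule (patch the module object a name is actually looked up on), using hypothetical stand-in modules rather than the real AutoGPT ones:

```python
import types
from unittest import mock

# Hypothetical stand-ins: a defining module and a consuming (route-like) module.
data = types.ModuleType("data")
data.get_node_executions = lambda: "real"

routes = types.ModuleType("routes")
# `from data import get_node_executions` copies the binding into `routes`:
routes.get_node_executions = data.get_node_executions

def handler():
    # Looks the name up on `routes`, not on `data`.
    return routes.get_node_executions()

# Patching the defining module does NOT affect the copied binding...
with mock.patch.object(data, "get_node_executions", return_value="mocked"):
    unaffected = handler()

# ...patching where the name is looked up does.
with mock.patch.object(routes, "get_node_executions", return_value="mocked"):
    affected = handler()

print(unaffected, affected)  # real mocked
```

Conversely, code that does `import`/lookup at call time on the defining module (as the rewritten route appears to) is affected by patching the defining module, which is why the test's patch target moved.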
@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
 from fastapi import APIRouter, HTTPException, Query, Security, status
 from prisma.enums import ReviewStatus
 
-from backend.copilot.constants import (
-    is_copilot_synthetic_id,
-    parse_node_id_from_exec_id,
-)
 from backend.data.execution import (
     ExecutionContext,
     ExecutionStatus,
     get_graph_execution_meta,
-    get_node_executions,
 )
 from backend.data.graph import get_graph_settings
 from backend.data.human_review import (
@@ -27,7 +22,6 @@ from backend.data.human_review import (
 )
 from backend.data.model import USER_TIMEZONE_NOT_SET
 from backend.data.user import get_user_by_id
-from backend.data.workspace import get_or_create_workspace
 from backend.executor.utils import add_graph_execution
 
 from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -41,38 +35,6 @@ router = APIRouter(
 )
 
 
-async def _resolve_node_ids(
-    node_exec_ids: list[str],
-    graph_exec_id: str,
-    is_copilot: bool,
-) -> dict[str, str]:
-    """Resolve node_exec_id -> node_id for auto-approval records.
-
-    CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
-    Graph executions look up node_id from NodeExecution records.
-    """
-    if not node_exec_ids:
-        return {}
-
-    if is_copilot:
-        return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
-
-    node_execs = await get_node_executions(
-        graph_exec_id=graph_exec_id, include_exec_data=False
-    )
-    node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
-
-    result = {}
-    for neid in node_exec_ids:
-        if neid in node_exec_map:
-            result[neid] = node_exec_map[neid]
-        else:
-            logger.error(
-                f"Failed to resolve node_id for {neid}: Node execution not found."
-            )
-    return result
-
-
 @router.get(
     "/pending",
     summary="Get Pending Reviews",
@@ -147,16 +109,14 @@ async def list_pending_reviews_for_execution(
     """
 
     # Verify user owns the graph execution before returning reviews
-    # (CoPilot synthetic IDs don't have graph execution records)
-    if not is_copilot_synthetic_id(graph_exec_id):
-        graph_exec = await get_graph_execution_meta(
-            user_id=user_id, execution_id=graph_exec_id
-        )
-        if not graph_exec:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"Graph execution #{graph_exec_id} not found",
-            )
+    graph_exec = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+    if not graph_exec:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Graph execution #{graph_exec_id} not found",
+        )
 
     return await get_pending_reviews_for_execution(graph_exec_id, user_id)
 
@@ -199,26 +159,30 @@ async def process_review_action(
     )
 
     graph_exec_id = next(iter(graph_exec_ids))
-    is_copilot = is_copilot_synthetic_id(graph_exec_id)
 
-    # Validate execution status for graph executions (skip for CoPilot synthetic IDs)
-    if not is_copilot:
-        graph_exec_meta = await get_graph_execution_meta(
-            user_id=user_id, execution_id=graph_exec_id
-        )
-        if not graph_exec_meta:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"Graph execution #{graph_exec_id} not found",
-            )
-        if graph_exec_meta.status not in (
-            ExecutionStatus.REVIEW,
-            ExecutionStatus.INCOMPLETE,
-        ):
-            raise HTTPException(
-                status_code=status.HTTP_409_CONFLICT,
-                detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
-            )
+    # Validate execution status before processing reviews
+    graph_exec_meta = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+
+    if not graph_exec_meta:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Graph execution #{graph_exec_id} not found",
+        )
+
+    # Only allow processing reviews if execution is paused for review
+    # or incomplete (partial execution with some reviews already processed)
+    if graph_exec_meta.status not in (
+        ExecutionStatus.REVIEW,
+        ExecutionStatus.INCOMPLETE,
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
+            f"Reviews can only be processed when execution is paused (REVIEW status). "
+            f"Current status: {graph_exec_meta.status}",
+        )
 
     # Build review decisions map and track which reviews requested auto-approval
     # Auto-approved reviews use original data (no modifications allowed)
@@ -271,7 +235,7 @@ async def process_review_action(
         )
         return (node_id, False)
 
-    # Collect node_exec_ids that need auto-approval and resolve their node_ids
+    # Collect node_exec_ids that need auto-approval
     node_exec_ids_needing_auto_approval = [
         node_exec_id
         for node_exec_id, review_result in updated_reviews.items()
@@ -279,16 +243,29 @@ async def process_review_action(
         and auto_approve_requests.get(node_exec_id, False)
     ]
 
-    node_id_map = await _resolve_node_ids(
-        node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
-    )
-
-    # Deduplicate by node_id — one auto-approval per node
+    # Batch-fetch node executions to get node_ids
    nodes_needing_auto_approval: dict[str, Any] = {}
-    for node_exec_id in node_exec_ids_needing_auto_approval:
-        node_id = node_id_map.get(node_exec_id)
-        if node_id and node_id not in nodes_needing_auto_approval:
-            nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
+    if node_exec_ids_needing_auto_approval:
+        from backend.data.execution import get_node_executions
+
+        node_execs = await get_node_executions(
+            graph_exec_id=graph_exec_id, include_exec_data=False
+        )
+        node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
+
+        for node_exec_id in node_exec_ids_needing_auto_approval:
+            node_exec = node_exec_map.get(node_exec_id)
+            if node_exec:
+                review_result = updated_reviews[node_exec_id]
+                # Use the first approved review for this node (deduplicate by node_id)
+                if node_exec.node_id not in nodes_needing_auto_approval:
+                    nodes_needing_auto_approval[node_exec.node_id] = review_result
+            else:
+                logger.error(
+                    f"Failed to create auto-approval record for {node_exec_id}: "
+                    f"Node execution not found. This may indicate a race condition "
+                    f"or data inconsistency."
+                )
 
     # Execute all auto-approval creations in parallel (deduplicated by node_id)
     auto_approval_results = await asyncio.gather(
@@ -303,11 +280,13 @@ async def process_review_action(
     auto_approval_failed_count = 0
     for result in auto_approval_results:
         if isinstance(result, Exception):
+            # Unexpected exception during auto-approval creation
             auto_approval_failed_count += 1
             logger.error(
                 f"Unexpected exception during auto-approval creation: {result}"
             )
         elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
+            # Auto-approval creation failed (returned False)
             auto_approval_failed_count += 1
 
     # Count results
@@ -322,31 +301,30 @@ async def process_review_action(
         if review.status == ReviewStatus.REJECTED
     )
 
-    # Resume graph execution only for real graph executions (not CoPilot)
-    # CoPilot sessions are resumed by the LLM retrying run_block with review_id
-    if not is_copilot and updated_reviews:
+    # Resume execution only if ALL pending reviews for this execution have been processed
+    if updated_reviews:
         still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
 
         if not still_has_pending:
+            # Get the graph_id from any processed review
             first_review = next(iter(updated_reviews.values()))
 
             try:
+                # Fetch user and settings to build complete execution context
                 user = await get_user_by_id(user_id)
                 settings = await get_graph_settings(
                     user_id=user_id, graph_id=first_review.graph_id
                 )
 
+                # Preserve user's timezone preference when resuming execution
                 user_timezone = (
                     user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
                 )
 
-                workspace = await get_or_create_workspace(user_id)
-
                 execution_context = ExecutionContext(
                     human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
                     sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
                     user_timezone=user_timezone,
-                    workspace_id=workspace.id,
                 )
 
                 await add_graph_execution(
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Annotated, Any, List, Literal
+from typing import TYPE_CHECKING, Annotated, List, Literal
 
 from autogpt_libs.auth import get_user_id
 from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
     Security,
     status,
 )
-from pydantic import BaseModel, Field, SecretStr, model_validator
+from pydantic import BaseModel, Field, SecretStr
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
 
 from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
 from backend.data.user import get_user_integrations
 from backend.executor.utils import add_graph_execution
 from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
-from backend.integrations.credentials_store import provider_matches
-from backend.integrations.creds_manager import (
-    IntegrationCredentialsManager,
-    create_mcp_oauth_handler,
-)
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
 from backend.integrations.providers import ProviderName
 from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
     scopes: list[str] | None
     username: str | None
     host: str | None = Field(
-        default=None,
-        description="Host pattern for host-scoped or MCP server URL for MCP credentials",
+        default=None, description="Host pattern for host-scoped credentials"
     )
 
-    @model_validator(mode="before")
-    @classmethod
-    def _normalize_provider(cls, data: Any) -> Any:
-        """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
-        if isinstance(data, dict):
-            prov = data.get("provider", "")
-            if isinstance(prov, str) and prov.startswith("ProviderName."):
-                member = prov.removeprefix("ProviderName.")
-                try:
-                    data = {**data, "provider": ProviderName[member].value}
-                except KeyError:
-                    pass
-        return data
-
-    @staticmethod
-    def get_host(cred: Credentials) -> str | None:
-        """Extract host from credential: HostScoped host or MCP server URL."""
-        if isinstance(cred, HostScopedCredentials):
-            return cred.host
-        if isinstance(cred, OAuth2Credentials) and cred.provider in (
-            ProviderName.MCP,
-            ProviderName.MCP.value,
-            "ProviderName.MCP",
-        ):
-            return (cred.metadata or {}).get("mcp_server_url")
-        return None
-
 
 @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
 async def callback(
@@ -211,7 +179,9 @@ async def callback(
         title=credentials.title,
         scopes=credentials.scopes,
         username=credentials.username,
-        host=(CredentialsMetaResponse.get_host(credentials)),
+        host=(
+            credentials.host if isinstance(credentials, HostScopedCredentials) else None
+        ),
     )
 
 
@@ -229,7 +199,7 @@ async def list_credentials(
             title=cred.title,
             scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
             username=cred.username if isinstance(cred, OAuth2Credentials) else None,
-            host=CredentialsMetaResponse.get_host(cred),
+            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
         )
         for cred in credentials
     ]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
             title=cred.title,
             scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
             username=cred.username if isinstance(cred, OAuth2Credentials) else None,
-            host=CredentialsMetaResponse.get_host(cred),
+            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
         )
         for cred in credentials
     ]
@@ -352,11 +322,7 @@ async def delete_credentials(
 
     tokens_revoked = None
     if isinstance(creds, OAuth2Credentials):
-        if provider_matches(provider.value, ProviderName.MCP.value):
-            # MCP uses dynamic per-server OAuth — create handler from metadata
-            handler = create_mcp_oauth_handler(creds)
-        else:
-            handler = _get_provider_oauth_handler(request, provider)
+        handler = _get_provider_oauth_handler(request, provider)
         tokens_revoked = await handler.revoke_tokens(creds)
 
     return CredentialsDeletionResponse(revoked=tokens_revoked)
File diff suppressed because it is too large
@@ -4,6 +4,7 @@ import prisma.enums
 import prisma.models
 import pytest
 
+import backend.api.features.store.exceptions
 from backend.data.db import connect
 from backend.data.includes import library_agent_include
 
@@ -143,7 +144,6 @@ async def test_add_agent_to_library(mocker):
     )
 
     mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
-    mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
     mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
     mock_library_agent.return_value.create = mocker.AsyncMock(
         return_value=mock_library_agent_data
@@ -178,6 +178,7 @@ async def test_add_agent_to_library(mocker):
                 "agentGraphVersion": 1,
             }
         },
+        include={"AgentGraph": True},
     )
     # Check that create was called with the expected data including settings
     create_call_args = mock_library_agent.return_value.create.call_args
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
     )
 
     # Call function and verify exception
-    with pytest.raises(db.NotFoundError):
+    with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
         await db.add_store_agent_to_library("version123", "test-user")
 
     # Verify mock called correctly
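The test hunks above mock async `find_unique`/`create` calls and assert that `create` receives `include={"AgentGraph": True}`. The same mocking pattern can be sketched with stdlib `unittest.mock` alone; the client object and `add_to_library` below are hypothetical stand-ins, not the real prisma models or library db code:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Hypothetical stand-in client in place of prisma.models.LibraryAgent.prisma().
client = MagicMock()
client.find_unique = AsyncMock(return_value=None)
client.create = AsyncMock(return_value={"id": "agent-1"})

async def add_to_library(graph_version_id: str, user_id: str) -> dict:
    # Create only when no existing record is found.
    existing = await client.find_unique(where={"id": graph_version_id})
    if existing is None:
        return await client.create(
            data={"userId": user_id},
            include={"AgentGraph": True},  # mirrors the include= assertion above
        )
    return existing

created = asyncio.run(add_to_library("version123", "test-user"))
client.create.assert_called_once()
assert client.create.call_args.kwargs["include"] == {"AgentGraph": True}
print(created)  # {'id': 'agent-1'}
```

`AsyncMock` records awaited calls the same way `MagicMock` records plain ones, so `call_args.kwargs` works for keyword-argument assertions on async APIs.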
@@ -1,10 +0,0 @@
-class FolderValidationError(Exception):
-    """Raised when folder operations fail validation."""
-
-    pass
-
-
-class FolderAlreadyExistsError(FolderValidationError):
-    """Raised when a folder with the same name already exists in the location."""
-
-    pass
@@ -6,12 +6,9 @@ import prisma.enums
 import prisma.models
 import pydantic
 
+from backend.data.block import BlockInput
 from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
-from backend.data.model import (
-    CredentialsMetaInput,
-    GraphInput,
-    is_credentials_field_name,
-)
+from backend.data.model import CredentialsMetaInput, is_credentials_field_name
 from backend.util.json import loads as json_loads
 from backend.util.models import Pagination
 
@@ -26,95 +23,6 @@ class LibraryAgentStatus(str, Enum):
     ERROR = "ERROR"
 
 
-# === Folder Models ===
-
-
-class LibraryFolder(pydantic.BaseModel):
-    """Represents a folder for organizing library agents."""
-
-    id: str
-    user_id: str
-    name: str
-    icon: str | None = None
-    color: str | None = None
-    parent_id: str | None = None
-    created_at: datetime.datetime
-    updated_at: datetime.datetime
-    agent_count: int = 0  # Direct agents in folder
-    subfolder_count: int = 0  # Direct child folders
-
-    @staticmethod
-    def from_db(
-        folder: prisma.models.LibraryFolder,
-        agent_count: int = 0,
-        subfolder_count: int = 0,
-    ) -> "LibraryFolder":
-        """Factory method that constructs a LibraryFolder from a Prisma model."""
-        return LibraryFolder(
-            id=folder.id,
-            user_id=folder.userId,
-            name=folder.name,
-            icon=folder.icon,
-            color=folder.color,
-            parent_id=folder.parentId,
-            created_at=folder.createdAt,
-            updated_at=folder.updatedAt,
-            agent_count=agent_count,
-            subfolder_count=subfolder_count,
-        )
-
-
-class LibraryFolderTree(LibraryFolder):
-    """Folder with nested children for tree view."""
-
-    children: list["LibraryFolderTree"] = []
-
-
-class FolderCreateRequest(pydantic.BaseModel):
-    """Request model for creating a folder."""
-
-    name: str = pydantic.Field(..., min_length=1, max_length=100)
-    icon: str | None = None
-    color: str | None = pydantic.Field(
-        None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
-    )
-    parent_id: str | None = None
-
-
-class FolderUpdateRequest(pydantic.BaseModel):
-    """Request model for updating a folder."""
-
-    name: str | None = pydantic.Field(None, min_length=1, max_length=100)
-    icon: str | None = None
-    color: str | None = None
-
-
-class FolderMoveRequest(pydantic.BaseModel):
-    """Request model for moving a folder to a new parent."""
-
-    target_parent_id: str | None = None  # None = move to root
-
-
-class BulkMoveAgentsRequest(pydantic.BaseModel):
-    """Request model for moving multiple agents to a folder."""
-
-    agent_ids: list[str]
-    folder_id: str | None = None  # None = move to root
-
-
-class FolderListResponse(pydantic.BaseModel):
-    """Response schema for a list of folders."""
-
-    folders: list[LibraryFolder]
-    pagination: Pagination
-
-
-class FolderTreeResponse(pydantic.BaseModel):
-    """Response schema for folder tree structure."""
-
-    tree: list[LibraryFolderTree]
-
-
 class MarketplaceListingCreator(pydantic.BaseModel):
     """Creator information for a marketplace listing."""
 
@@ -165,6 +73,7 @@ class LibraryAgent(pydantic.BaseModel):
     id: str
     graph_id: str
     graph_version: int
+    owner_user_id: str
 
     image_url: str | None
 
@@ -205,14 +114,9 @@ class LibraryAgent(pydantic.BaseModel):
         default_factory=list,
         description="List of recent executions with status, score, and summary",
     )
-    can_access_graph: bool = pydantic.Field(
-        description="Indicates whether the same user owns the corresponding graph"
-    )
+    can_access_graph: bool
     is_latest_version: bool
     is_favorite: bool
-    folder_id: str | None = None
-    folder_name: str | None = None  # Denormalized for display
 
     recommended_schedule_cron: str | None = None
     settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
     marketplace_listing: Optional["MarketplaceListing"] = None
@@ -325,6 +229,7 @@ class LibraryAgent(pydantic.BaseModel):
             id=agent.id,
             graph_id=agent.agentGraphId,
             graph_version=agent.agentGraphVersion,
+            owner_user_id=agent.userId,
             image_url=agent.imageUrl,
             creator_name=creator_name,
             creator_image_url=creator_image_url,
@@ -351,8 +256,6 @@ class LibraryAgent(pydantic.BaseModel):
             can_access_graph=can_access_graph,
             is_latest_version=is_latest_version,
             is_favorite=agent.isFavorite,
-            folder_id=agent.folderId,
-            folder_name=agent.Folder.name if agent.Folder else None,
             recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
             settings=_parse_settings(agent.settings),
             marketplace_listing=marketplace_listing_data,
@@ -420,7 +323,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
     graph_id: str
     graph_version: int
 
-    inputs: GraphInput
+    inputs: BlockInput
     credentials: dict[str, CredentialsMetaInput]
 
     name: str
@@ -449,7 +352,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
     Request model used when updating a preset for a library agent.
     """
 
-    inputs: Optional[GraphInput] = None
+    inputs: Optional[BlockInput] = None
     credentials: Optional[dict[str, CredentialsMetaInput]] = None
 
     name: Optional[str] = None
@@ -492,7 +395,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
                 "Webhook must be included in AgentPreset query when webhookId is set"
             )
 
-        input_data: GraphInput = {}
+        input_data: BlockInput = {}
         input_credentials: dict[str, CredentialsMetaInput] = {}
 
         for preset_input in preset.InputPresets:
@@ -564,7 +467,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
     settings: Optional[GraphSettings] = pydantic.Field(
         default=None, description="User-specific settings for this library agent"
     )
-    folder_id: Optional[str] = pydantic.Field(
-        default=None,
-        description="Folder ID to move agent to (None to move to root)",
-    )
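The model hunks above retype preset inputs from `GraphInput` to `BlockInput`. Assuming `BlockInput` is a plain name-to-value mapping (an assumption; the real alias lives in `backend.data.block` and is not shown in this diff), merging a preset's stored `input_data` with call-time overrides could look like:

```python
from typing import Any

# Assumed alias for illustration; the real definition is in backend.data.block.
BlockInput = dict[str, Any]

def merge_inputs(preset: BlockInput, overrides: BlockInput) -> BlockInput:
    # Call-time values win over the preset's stored defaults.
    merged = dict(preset)
    merged.update(overrides)
    return merged

print(merge_inputs({"query": "cats", "limit": 10}, {"limit": 25}))
# {'query': 'cats', 'limit': 25}
```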
Some files were not shown because too many files have changed in this diff.