mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
7 Commits
fix/copilo
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad1a814724 | ||
|
|
562cf04ab6 | ||
|
|
90b3b5ba16 | ||
|
|
f4f81bc4fc | ||
|
|
c5abc01f25 | ||
|
|
8b7053c1de | ||
|
|
e00c1202ad |
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
@@ -1,35 +0,0 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
@@ -1,16 +0,0 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,51 +0,0 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
## Rules
|
||||
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
@@ -5,13 +5,42 @@
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
!autogpt_platform/autogpt_libs/poetry.lock
|
||||
!autogpt_platform/autogpt_libs/README.md
|
||||
|
||||
# Platform - Backend
|
||||
!autogpt_platform/backend/
|
||||
!autogpt_platform/backend/backend/
|
||||
!autogpt_platform/backend/test/e2e_test_data.py
|
||||
!autogpt_platform/backend/migrations/
|
||||
!autogpt_platform/backend/schema.prisma
|
||||
!autogpt_platform/backend/pyproject.toml
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
!autogpt_platform/market/scripts.py
|
||||
!autogpt_platform/market/schema.prisma
|
||||
!autogpt_platform/market/pyproject.toml
|
||||
!autogpt_platform/market/poetry.lock
|
||||
!autogpt_platform/market/README.md
|
||||
|
||||
# Platform - Frontend
|
||||
!autogpt_platform/frontend/
|
||||
!autogpt_platform/frontend/src/
|
||||
!autogpt_platform/frontend/public/
|
||||
!autogpt_platform/frontend/scripts/
|
||||
!autogpt_platform/frontend/package.json
|
||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
||||
!autogpt_platform/frontend/tsconfig.json
|
||||
!autogpt_platform/frontend/README.md
|
||||
## config
|
||||
!autogpt_platform/frontend/*.config.*
|
||||
!autogpt_platform/frontend/.env.*
|
||||
!autogpt_platform/frontend/.env
|
||||
|
||||
# Classic - AutoGPT
|
||||
!classic/original_autogpt/autogpt/
|
||||
@@ -35,38 +64,6 @@
|
||||
# Classic - Frontend
|
||||
!classic/frontend/build/web/
|
||||
|
||||
# Explicitly re-ignore unwanted files from whitelisted directories
|
||||
# Note: These patterns MUST come after the whitelist rules to take effect
|
||||
|
||||
# Hidden files and directories (but keep frontend .env files needed for build)
|
||||
**/.*
|
||||
!autogpt_platform/frontend/.env
|
||||
!autogpt_platform/frontend/.env.default
|
||||
!autogpt_platform/frontend/.env.production
|
||||
|
||||
# Python artifacts
|
||||
**/__pycache__/
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/.venv/
|
||||
**/.ruff_cache/
|
||||
**/.pytest_cache/
|
||||
**/.coverage
|
||||
**/htmlcov/
|
||||
|
||||
# Node artifacts
|
||||
**/node_modules/
|
||||
**/.next/
|
||||
**/storybook-static/
|
||||
**/playwright-report/
|
||||
**/test-results/
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
!autogpt_platform/frontend/src/**/build/
|
||||
**/target/
|
||||
|
||||
# Logs and temp files
|
||||
**/*.log
|
||||
**/*.tmp
|
||||
# Explicitly re-ignore some folders
|
||||
.*
|
||||
**/__pycache__
|
||||
|
||||
1229
.github/scripts/detect_overlaps.py
vendored
1229
.github/scripts/detect_overlaps.py
vendored
File diff suppressed because it is too large
Load Diff
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
|
||||
46
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
46
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.workflow_run.head_branch }}
|
||||
fetch-depth: 0
|
||||
@@ -40,51 +40,9 @@ jobs:
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||
|
||||
# Backend Python/Poetry setup (so Claude can run linting/tests)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
|
||||
29
.github/workflows/claude-dependabot.yml
vendored
29
.github/workflows/claude-dependabot.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -77,15 +77,27 @@ jobs:
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
@@ -112,7 +124,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
@@ -297,7 +309,6 @@ jobs:
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
allowed_bots: "dependabot[bot]"
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
|
||||
28
.github/workflows/claude.yml
vendored
28
.github/workflows/claude.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -93,15 +93,27 @@ jobs:
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
@@ -128,7 +140,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
6
.github/workflows/codeql.yml
vendored
6
.github/workflows/codeql.yml
vendored
@@ -58,11 +58,11 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v4
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
@@ -93,6 +93,6 @@ jobs:
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v4
|
||||
uses: github/codeql-action/analyze@v3
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
10
.github/workflows/copilot-setup-steps.yml
vendored
10
.github/workflows/copilot-setup-steps.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
4
.github/workflows/docs-block-sync.yml
vendored
4
.github/workflows/docs-block-sync.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
38
.github/workflows/docs-claude-review.yml
vendored
38
.github/workflows/docs-claude-review.yml
vendored
@@ -7,10 +7,6 @@ on:
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
|
||||
concurrency:
|
||||
group: claude-docs-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Only run for PRs from members/collaborators
|
||||
@@ -27,7 +23,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -37,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -95,35 +91,5 @@ jobs:
|
||||
3. Read corresponding documentation files to verify accuracy
|
||||
4. Provide your feedback as a PR comment
|
||||
|
||||
## IMPORTANT: Comment Marker
|
||||
Start your PR comment with exactly this HTML comment marker on its own line:
|
||||
<!-- CLAUDE_DOCS_REVIEW -->
|
||||
|
||||
This marker is used to identify and replace your comment on subsequent runs.
|
||||
|
||||
Be constructive and specific. If everything looks good, say so!
|
||||
If there are issues, explain what's wrong and suggest how to fix it.
|
||||
|
||||
- name: Delete old Claude review comments
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Get all comment IDs with our marker, sorted by creation date (oldest first)
|
||||
COMMENT_IDS=$(gh api \
|
||||
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
|
||||
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
|
||||
|
||||
# Count comments
|
||||
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
|
||||
|
||||
if [ "$COMMENT_COUNT" -gt 1 ]; then
|
||||
# Delete all but the last (newest) comment
|
||||
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
|
||||
if [ -n "$COMMENT_ID" ]; then
|
||||
echo "Deleting old review comment: $COMMENT_ID"
|
||||
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "No old review comments to clean up"
|
||||
fi
|
||||
|
||||
4
.github/workflows/docs-enhance.yml
vendored
4
.github/workflows/docs-enhance.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger deploy workflow
|
||||
uses: peter-evans/repository-dispatch@v4
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger deploy workflow
|
||||
uses: peter-evans/repository-dispatch@v4
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
|
||||
13
.github/workflows/platform-backend-ci.yml
vendored
13
.github/workflows/platform-backend-ci.yml
vendored
@@ -41,18 +41,13 @@ jobs:
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
image: rabbitmq:4.1.4
|
||||
image: rabbitmq:3.12-management
|
||||
ports:
|
||||
- 5672:5672
|
||||
- 15672:15672
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
options: >-
|
||||
--health-cmd "rabbitmq-diagnostics -q ping"
|
||||
--health-interval 30s
|
||||
--health-timeout 10s
|
||||
--health-retries 5
|
||||
--health-start-period 10s
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
@@ -73,7 +68,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
@@ -93,7 +88,7 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
@@ -82,7 +82,7 @@ jobs:
|
||||
|
||||
- name: Dispatch Deploy Event
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v4
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -110,7 +110,7 @@ jobs:
|
||||
|
||||
- name: Dispatch Undeploy Event (from comment)
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v4
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
@@ -168,7 +168,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v4
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
263
.github/workflows/platform-frontend-ci.yml
vendored
263
.github/workflows/platform-frontend-ci.yml
vendored
@@ -6,16 +6,10 @@ on:
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
- "autogpt_platform/backend/Dockerfile"
|
||||
- "autogpt_platform/docker-compose.yml"
|
||||
- "autogpt_platform/docker-compose.platform.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
- "autogpt_platform/backend/Dockerfile"
|
||||
- "autogpt_platform/docker-compose.yml"
|
||||
- "autogpt_platform/docker-compose.platform.yml"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -32,11 +26,12 @@ jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
components-changed: ${{ steps.filter.outputs.components }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check for component changes
|
||||
uses: dorny/paths-filter@v3
|
||||
@@ -46,17 +41,28 @@ jobs:
|
||||
components:
|
||||
- 'autogpt_platform/frontend/src/components/**'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install dependencies to populate cache
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
lint:
|
||||
@@ -65,17 +71,24 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
@@ -94,19 +107,26 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
@@ -121,20 +141,30 @@ jobs:
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
@@ -142,125 +172,77 @@ jobs:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-frontend-test-
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
- name: Move cache
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
rm -rf /tmp/.buildx-cache
|
||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
fi
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
- name: Create E2E test data
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
# First try to run the script from inside the container
|
||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||
# Copy the script into the container and run it
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||
echo "❌ Failed to copy script to container"
|
||||
exit 1
|
||||
}
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
@@ -287,7 +269,7 @@ jobs:
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -295,19 +277,26 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
16
.github/workflows/platform-fullstack-ci.yml
vendored
16
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -29,10 +29,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -56,19 +56,19 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: big-boi
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -85,10 +85,10 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
39
.github/workflows/pr-overlap-check.yml
vendored
39
.github/workflows/pr-overlap-check.yml
vendored
@@ -1,39 +0,0 @@
|
||||
name: PR Overlap Detection
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
branches:
|
||||
- dev
|
||||
- master
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-overlaps:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Need full history for merge testing
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Configure git
|
||||
run: |
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config user.name "github-actions[bot]"
|
||||
|
||||
- name: Run overlap detection
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# Always succeed - this check informs contributors, it shouldn't block merging
|
||||
continue-on-error: true
|
||||
run: |
|
||||
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
|
||||
2
.github/workflows/repo-workflow-checker.yml
vendored
2
.github/workflows/repo-workflow-checker.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
steps:
|
||||
# - name: Wait some time for all actions to start
|
||||
# run: sleep 30
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
- name: Set up Python
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Add cache configuration to a resolved docker-compose file for all services
|
||||
that have a build key, and ensure image names match what docker compose expects.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
DEFAULT_BRANCH = "dev"
|
||||
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add cache config to a resolved compose file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
required=True,
|
||||
help="Source compose file to read (should be output of `docker compose config`)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-from",
|
||||
default="type=gha",
|
||||
help="Cache source configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-to",
|
||||
default="type=gha,mode=max",
|
||||
help="Cache destination configuration",
|
||||
)
|
||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||
parser.add_argument(
|
||||
f"--{component}-hash",
|
||||
default="",
|
||||
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--git-ref",
|
||||
default="",
|
||||
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
|
||||
git_ref_scope = ""
|
||||
if args.git_ref:
|
||||
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
|
||||
|
||||
with open(args.source, "r") as f:
|
||||
compose = yaml.safe_load(f)
|
||||
|
||||
# Get project name from compose file or default
|
||||
project_name = compose.get("name", "autogpt_platform")
|
||||
|
||||
def get_image_name(dockerfile: str, target: str) -> str:
|
||||
"""Generate image name based on Dockerfile folder and build target."""
|
||||
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
|
||||
if len(dockerfile_parts) >= 2:
|
||||
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
|
||||
else:
|
||||
folder_name = "app"
|
||||
return f"{project_name}-{folder_name}:{target}"
|
||||
|
||||
def get_build_key(dockerfile: str, target: str) -> str:
|
||||
"""Generate a unique key for a Dockerfile+target combination."""
|
||||
return f"{dockerfile}:{target}"
|
||||
|
||||
def get_component(dockerfile: str) -> str | None:
|
||||
"""Get component name (frontend/backend) from dockerfile path."""
|
||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||
if component in dockerfile:
|
||||
return component
|
||||
return None
|
||||
|
||||
# First pass: collect all services with build configs and identify duplicates
|
||||
# Track which (dockerfile, target) combinations we've seen
|
||||
build_key_to_first_service: dict[str, str] = {}
|
||||
services_to_build: list[str] = []
|
||||
services_to_dedupe: list[str] = []
|
||||
|
||||
for service_name, service_config in compose.get("services", {}).items():
|
||||
if "build" not in service_config:
|
||||
continue
|
||||
|
||||
build_config = service_config["build"]
|
||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||
target = build_config.get("target", "default")
|
||||
build_key = get_build_key(dockerfile, target)
|
||||
|
||||
if build_key not in build_key_to_first_service:
|
||||
# First service with this build config - it will do the actual build
|
||||
build_key_to_first_service[build_key] = service_name
|
||||
services_to_build.append(service_name)
|
||||
else:
|
||||
# Duplicate - will just use the image from the first service
|
||||
services_to_dedupe.append(service_name)
|
||||
|
||||
# Second pass: configure builds and deduplicate
|
||||
modified_services = []
|
||||
for service_name, service_config in compose.get("services", {}).items():
|
||||
if "build" not in service_config:
|
||||
continue
|
||||
|
||||
build_config = service_config["build"]
|
||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||
target = build_config.get("target", "latest")
|
||||
image_name = get_image_name(dockerfile, target)
|
||||
|
||||
# Set image name for all services (needed for both builders and deduped)
|
||||
service_config["image"] = image_name
|
||||
|
||||
if service_name in services_to_dedupe:
|
||||
# Remove build config - this service will use the pre-built image
|
||||
del service_config["build"]
|
||||
continue
|
||||
|
||||
# This service will do the actual build - add cache config
|
||||
cache_from_list = []
|
||||
cache_to_list = []
|
||||
|
||||
component = get_component(dockerfile)
|
||||
if not component:
|
||||
# Skip services that don't clearly match frontend/backend
|
||||
continue
|
||||
|
||||
# Get the hash for this component
|
||||
component_hash = getattr(args, f"{component}_hash")
|
||||
|
||||
# Scope format: platform-{component}-{target}-{hash|ref}
|
||||
# Example: platform-backend-server-abc123
|
||||
|
||||
if "type=gha" in args.cache_from:
|
||||
# 1. Primary: exact hash match (most specific)
|
||||
if component_hash:
|
||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
|
||||
|
||||
# 2. Fallback: branch-based cache
|
||||
if git_ref_scope:
|
||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
|
||||
|
||||
# 3. Fallback: dev branch cache (for PRs/feature branches)
|
||||
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
|
||||
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
|
||||
|
||||
if "type=gha" in args.cache_to:
|
||||
# Write to both hash-based and branch-based scopes
|
||||
if component_hash:
|
||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
|
||||
|
||||
if git_ref_scope:
|
||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
|
||||
|
||||
# Ensure we have at least one cache source/target
|
||||
if not cache_from_list:
|
||||
cache_from_list.append(args.cache_from)
|
||||
if not cache_to_list:
|
||||
cache_to_list.append(args.cache_to)
|
||||
|
||||
build_config["cache_from"] = cache_from_list
|
||||
build_config["cache_to"] = cache_to_list
|
||||
modified_services.append(service_name)
|
||||
|
||||
# Write back to the same file
|
||||
with open(args.source, "w") as f:
|
||||
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
|
||||
for svc in modified_services:
|
||||
svc_config = compose["services"][svc]
|
||||
build_cfg = svc_config.get("build", {})
|
||||
cache_from_list = build_cfg.get("cache_from", ["none"])
|
||||
cache_to_list = build_cfg.get("cache_to", ["none"])
|
||||
print(f" - {svc}")
|
||||
print(f" image: {svc_config.get('image', 'N/A')}")
|
||||
print(f" cache_from: {cache_from_list}")
|
||||
print(f" cache_to: {cache_to_list}")
|
||||
if services_to_dedupe:
|
||||
print(
|
||||
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
|
||||
)
|
||||
for svc in services_to_dedupe:
|
||||
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -180,6 +180,4 @@ autogpt_platform/backend/settings.py
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.next
|
||||
@@ -1,10 +1,3 @@
|
||||
default_install_hook_types:
|
||||
- pre-commit
|
||||
- pre-push
|
||||
- post-checkout
|
||||
|
||||
default_stages: [pre-commit]
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
@@ -24,7 +17,6 @@ repos:
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
|
||||
- repo: local
|
||||
@@ -34,106 +26,49 @@ repos:
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Backend
|
||||
alias: poetry-install-platform-backend
|
||||
entry: poetry -C autogpt_platform/backend install
|
||||
# include autogpt_libs source (since it's a path dependency)
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/backend install
|
||||
'
|
||||
always_run: true
|
||||
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Libs
|
||||
alias: poetry-install-platform-libs
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/autogpt_libs install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C autogpt_platform/autogpt_libs install
|
||||
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: pnpm-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Frontend
|
||||
alias: pnpm-install-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
|
||||
pnpm --prefix autogpt_platform/frontend install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
entry: poetry -C classic/original_autogpt install
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C classic/forge install
|
||||
files: ^classic/forge/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C classic/benchmark install
|
||||
files: ^classic/benchmark/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, Prisma client must be up-to-date.
|
||||
@@ -141,54 +76,12 @@ repos:
|
||||
- id: prisma-generate
|
||||
name: Prisma Generate - AutoGPT Platform - Backend
|
||||
alias: prisma-generate-platform-backend
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run prisma generate
|
||||
&& poetry run gen-prisma-stub
|
||||
'
|
||||
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
|
||||
# include everything that triggers poetry install + the prisma schema
|
||||
always_run: true
|
||||
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: export-api-schema
|
||||
name: Export API schema - AutoGPT Platform - Backend -> Frontend
|
||||
alias: export-api-schema-platform
|
||||
entry: >
|
||||
bash -c '
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
&& cd ../frontend
|
||||
&& pnpm prettier --write ./src/app/api/openapi.json
|
||||
'
|
||||
files: ^autogpt_platform/backend/
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: generate-api-client
|
||||
name: Generate API client - AutoGPT Platform - Frontend
|
||||
alias: generate-api-client-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
|
||||
else
|
||||
git diff --quiet HEAD -- "$SCHEMA" && exit 0
|
||||
fi;
|
||||
cd autogpt_platform/frontend && pnpm generate:api
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.7.2
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
*.ign.*
|
||||
@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.auth_activities
|
||||
-- Looker source alias: ds49 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Tracks authentication events (login, logout, SSO, password
|
||||
-- reset, etc.) from Supabase's internal audit log.
|
||||
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.audit_log_entries — Supabase internal auth event log
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||
-- actor_id TEXT User ID who triggered the event
|
||||
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days from current date
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily login counts
|
||||
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||
-- FROM analytics.auth_activities
|
||||
-- WHERE action = 'login'
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
--
|
||||
-- -- SSO vs password login breakdown
|
||||
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||
-- WHERE action = 'login' GROUP BY 1;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
created_at,
|
||||
payload->>'actor_id' AS actor_id,
|
||||
payload->>'actor_via_sso' AS actor_via_sso,
|
||||
payload->>'action' AS action
|
||||
FROM auth.audit_log_entries
|
||||
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||
@@ -1,105 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.graph_execution
|
||||
-- Looker source alias: ds16 | Charts: 21
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per agent graph execution (last 90 days).
|
||||
-- Unpacks the JSONB stats column into individual numeric columns
|
||||
-- and normalises the executionStatus — runs that failed due to
|
||||
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||
-- to allow safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Execution UUID
|
||||
-- agentGraphId TEXT Agent graph UUID
|
||||
-- agentGraphVersion INT Graph version number
|
||||
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||
-- updatedAt TIMESTAMPTZ Last status update time
|
||||
-- userId TEXT Owner user UUID
|
||||
-- agentGraphName TEXT Human-readable agent name
|
||||
-- cputime DECIMAL Total CPU seconds consumed
|
||||
-- walltime DECIMAL Total wall-clock seconds
|
||||
-- node_count DECIMAL Number of nodes in the graph
|
||||
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||
-- execution_cost DECIMAL Credit cost of this execution
|
||||
-- correctness_score FLOAT AI correctness score (if available)
|
||||
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Daily execution counts by status
|
||||
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||
-- FROM analytics.graph_execution
|
||||
-- GROUP BY 1, 2 ORDER BY 1;
|
||||
--
|
||||
-- -- Average cost per execution by agent
|
||||
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'COMPLETED'
|
||||
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||
--
|
||||
-- -- Top error messages
|
||||
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||
-- FROM analytics.graph_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ge."id" AS id,
|
||||
ge."agentGraphId" AS agentGraphId,
|
||||
ge."agentGraphVersion" AS agentGraphVersion,
|
||||
CASE
|
||||
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||
AND (
|
||||
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||
)
|
||||
THEN 'NO_CREDITS'
|
||||
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||
END AS executionStatus,
|
||||
ge."createdAt" AS createdAt,
|
||||
ge."updatedAt" AS updatedAt,
|
||||
ge."userId" AS userId,
|
||||
g."name" AS agentGraphName,
|
||||
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage
|
||||
FROM platform."AgentGraphExecution" ge
|
||||
LEFT JOIN platform."AgentGraph" g
|
||||
ON ge."agentGraphId" = g."id"
|
||||
AND ge."agentGraphVersion" = g."version"
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||
"userId", "agentGraphId",
|
||||
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||
FROM platform."LibraryAgent"
|
||||
WHERE "isDeleted" = FALSE
|
||||
AND "isArchived" = FALSE
|
||||
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,101 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.node_block_execution
|
||||
-- Looker source alias: ds14 | Charts: 11
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per node (block) execution (last 90 days).
|
||||
-- Unpacks stats JSONB and joins to identify which block type
|
||||
-- was run. For failed nodes, joins the error output and
|
||||
-- scrubs it for safe grouping.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentNodeExecution — Node execution records
|
||||
-- platform.AgentNode — Node → block mapping
|
||||
-- platform.AgentBlock — Block name/ID
|
||||
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Node execution UUID
|
||||
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||
-- agentNodeId TEXT Node UUID within the graph
|
||||
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||
-- addedTime TIMESTAMPTZ When the node was queued
|
||||
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||
-- startedTime TIMESTAMPTZ When execution started
|
||||
-- endedTime TIMESTAMPTZ When execution finished
|
||||
-- inputSize BIGINT Input payload size in bytes
|
||||
-- outputSize BIGINT Output payload size in bytes
|
||||
-- walltime NUMERIC Wall-clock seconds for this node
|
||||
-- cputime NUMERIC CPU seconds for this node
|
||||
-- llmRetryCount INT Number of LLM retries
|
||||
-- llmCallCount INT Number of LLM API calls made
|
||||
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||
-- outputTokenCount BIGINT LLM output tokens produced
|
||||
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||
-- blockId TEXT Block UUID
|
||||
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Most-used blocks by execution count
|
||||
-- SELECT "blockName", COUNT(*) AS executions,
|
||||
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||
-- FROM analytics.node_block_execution
|
||||
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||
--
|
||||
-- -- Average LLM token usage per block
|
||||
-- SELECT "blockName",
|
||||
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "llmCallCount" > 0
|
||||
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||
--
|
||||
-- -- Top failure reasons
|
||||
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||
-- FROM analytics.node_block_execution
|
||||
-- WHERE "executionStatus" = 'FAILED'
|
||||
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
ne."id" AS id,
|
||||
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||
ne."agentNodeId" AS agentNodeId,
|
||||
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||
ne."addedTime" AS addedTime,
|
||||
ne."queuedTime" AS queuedTime,
|
||||
ne."startedTime" AS startedTime,
|
||||
ne."endedTime" AS endedTime,
|
||||
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||
b."name" AS blockName,
|
||||
b."id" AS blockId,
|
||||
REGEXP_REPLACE(
|
||||
REGEXP_REPLACE(
|
||||
TRIM(BOTH '"' FROM eio."data"::text),
|
||||
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||
'\1\2/...', 'gi'
|
||||
),
|
||||
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||
) AS groupedErrorMessage,
|
||||
eio."data" AS errorMessage
|
||||
FROM platform."AgentNodeExecution" ne
|
||||
LEFT JOIN platform."AgentNode" nd
|
||||
ON ne."agentNodeId" = nd."id"
|
||||
LEFT JOIN platform."AgentBlock" b
|
||||
ON nd."agentBlockId" = b."id"
|
||||
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||
ON eio."referencedByOutputExecId" = ne."id"
|
||||
AND eio."name" = 'error'
|
||||
AND ne."executionStatus" = 'FAILED'
|
||||
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,97 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_agent
|
||||
-- Looker source alias: ds35 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention broken down per individual agent.
|
||||
-- Cohort = week of a user's first use of THAT specific agent.
|
||||
-- Tells you which agents keep users coming back vs. one-shot
|
||||
-- use. Only includes cohorts from the last 180 days.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||
-- platform.AgentGraph — Agent names
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- agent_id TEXT Agent graph UUID
|
||||
-- agent_label TEXT 'AgentName [first8chars]'
|
||||
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||
-- cohort_week_start DATE Week users first ran this agent
|
||||
-- cohort_label TEXT ISO week label
|
||||
-- cohort_label_n TEXT ISO week label with cohort size
|
||||
-- user_lifetime_week INT Weeks since first use of this agent
|
||||
-- cohort_users BIGINT Users in this cohort for this agent
|
||||
-- active_users BIGINT Users who ran the agent again in week k
|
||||
-- retention_rate FLOAT active_users / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Best-retained agents at week 2
|
||||
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||
-- FROM analytics.retention_agent
|
||||
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||
--
|
||||
-- -- Agents with most unique users
|
||||
-- SELECT DISTINCT agent_label, agent_total_users
|
||||
-- FROM analytics.retention_agent
|
||||
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||
e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e
|
||||
),
|
||||
first_use AS (
|
||||
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1,2
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||
),
|
||||
active_counts AS (
|
||||
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||
),
|
||||
cohort_sizes AS (
|
||||
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||
),
|
||||
cohort_caps AS (
|
||||
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
),
|
||||
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||
SELECT
|
||||
g.agent_id,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(ac.active_users,0) AS active_users,
|
||||
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||
FROM grid g
|
||||
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,81 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_daily
|
||||
-- Looker source alias: ds111 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on agent executions.
|
||||
-- Cohort anchor = day of user's FIRST ever execution.
|
||||
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||
-- Great for early engagement analysis (did users run another
|
||||
-- agent the next day?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_daily.
|
||||
-- cohort_day_start = day of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-3 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||
-- FROM analytics.retention_execution_daily
|
||||
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fe.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -1,81 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_execution_weekly
|
||||
-- Looker source alias: ds92 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on agent executions.
|
||||
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||
-- (not first login). Only includes cohorts from the last 180 days.
|
||||
-- Useful when you care about product engagement, not just visits.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.AgentGraphExecution — Execution records
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same pattern as retention_login_weekly.
|
||||
-- cohort_week_start = week of first execution (not first login)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-2 execution retention
|
||||
-- SELECT cohort_label, retention_rate_bounded
|
||||
-- FROM analytics.retention_execution_weekly
|
||||
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||
events AS (
|
||||
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||
),
|
||||
first_exec AS (
|
||||
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fe.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,94 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_daily
|
||||
-- Looker source alias: ds112 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Daily cohort retention based on login sessions.
|
||||
-- Same logic as retention_login_weekly but at day granularity,
|
||||
-- showing up to day 30 for cohorts from the last 90 days.
|
||||
-- Useful for analysing early activation (days 1-7) in detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||
-- cohort_day_start DATE First day the cohort logged in
|
||||
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||
-- cohort_users BIGINT Total users in cohort
|
||||
-- active_users_bounded BIGINT Users active on exactly day k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Day-1 retention rate (came back next day)
|
||||
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||
--
|
||||
-- -- Average retention curve across all cohorts
|
||||
-- SELECT user_lifetime_day,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_daily
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||
FROM events GROUP BY 1
|
||||
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||
),
|
||||
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||
user_day_age AS (
|
||||
SELECT ad.user_id, fl.cohort_day_start,
|
||||
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_day_start,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_day, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||
@@ -1,96 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_onboarded_weekly
|
||||
-- Looker source alias: ds101 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention from login sessions, restricted to
|
||||
-- users who "onboarded" — defined as running at least one
|
||||
-- agent within 365 days of their first login.
|
||||
-- Filters out users who signed up but never activated,
|
||||
-- giving a cleaner view of engaged-user retention.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||
-- Only difference: cohort is filtered to onboarded users only.
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Compare week-4 retention: all users vs onboarded only
|
||||
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||
-- UNION ALL
|
||||
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login_all AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
onboarders AS (
|
||||
SELECT fl.user_id FROM first_login_all fl
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||
WHERE e."userId"::text = fl.user_id
|
||||
AND e."createdAt" >= fl.first_login_time
|
||||
AND e."createdAt" < fl.first_login_time
|
||||
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||
)
|
||||
),
|
||||
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||
@@ -1,103 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.retention_login_weekly
|
||||
-- Looker source alias: ds83 | Charts: 2
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Weekly cohort retention based on login sessions.
|
||||
-- Users are grouped by the ISO week of their first ever login.
|
||||
-- For each cohort × lifetime-week combination, outputs both:
|
||||
-- - bounded rate: % active in exactly that week
|
||||
-- - unbounded rate: % who were ever active on or after that week
|
||||
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login session records
|
||||
--
|
||||
-- HOW TO READ THE OUTPUT
|
||||
-- cohort_week_start The Monday of the week users first logged in
|
||||
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- cohort_week_start DATE First day of the cohort's signup week
|
||||
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||
-- active_users_bounded BIGINT Users active in exactly week k
|
||||
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Week-1 retention rate per cohort
|
||||
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- WHERE user_lifetime_week = 1
|
||||
-- ORDER BY cohort_week_start;
|
||||
--
|
||||
-- -- Overall average retention curve (all cohorts combined)
|
||||
-- SELECT user_lifetime_week,
|
||||
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||
-- FROM analytics.retention_login_weekly
|
||||
-- GROUP BY 1 ORDER BY 1;
|
||||
-- =============================================================
|
||||
|
||||
WITH params AS (SELECT 12::int AS max_weeks),
|
||||
events AS (
|
||||
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||
),
|
||||
first_login AS (
|
||||
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||
FROM events GROUP BY 1
|
||||
),
|
||||
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||
user_week_age AS (
|
||||
SELECT aw.user_id, fl.cohort_week_start,
|
||||
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||
),
|
||||
bounded_counts AS (
|
||||
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||
),
|
||||
last_active AS (
|
||||
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||
),
|
||||
unbounded_counts AS (
|
||||
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||
FROM last_active la
|
||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||
GROUP BY 1,2
|
||||
),
|
||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||
cohort_caps AS (
|
||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||
LEAST((SELECT max_weeks FROM params),
|
||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||
FROM cohort_sizes cs
|
||||
),
|
||||
grid AS (
|
||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||
)
|
||||
SELECT
|
||||
g.cohort_week_start,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||
g.user_lifetime_week, g.cohort_users,
|
||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||
FROM grid g
|
||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||
@@ -1,71 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_block_spending
|
||||
-- Looker source alias: ds6 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per credit transaction (last 90 days).
|
||||
-- Shows how users spend credits broken down by block type,
|
||||
-- LLM provider and model. Joins node execution stats for
|
||||
-- token-level detail.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.CreditTransaction — Credit debit/credit records
|
||||
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- transactionKey TEXT Unique transaction identifier
|
||||
-- userId TEXT User who was charged
|
||||
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||
-- blockId TEXT Block UUID that triggered the spend
|
||||
-- blockName TEXT Human-readable block name
|
||||
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||
-- node_exec_id TEXT Linked node execution UUID
|
||||
-- llm_call_count INT LLM API calls made in that execution
|
||||
-- llm_retry_count INT LLM retries in that execution
|
||||
-- llm_input_token_count INT Input tokens consumed
|
||||
-- llm_output_token_count INT Output tokens produced
|
||||
--
|
||||
-- WINDOW
|
||||
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Total spend per user (last 90 days)
|
||||
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "transactionType" = 'USAGE'
|
||||
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||
--
|
||||
-- -- Spend by LLM provider + model
|
||||
-- SELECT "llm_provider", "llm_model",
|
||||
-- SUM("negativeAmount") AS total_cost,
|
||||
-- SUM("llm_input_token_count") AS input_tokens,
|
||||
-- SUM("llm_output_token_count") AS output_tokens
|
||||
-- FROM analytics.user_block_spending
|
||||
-- WHERE "llm_provider" IS NOT NULL
|
||||
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
c."transactionKey" AS transactionKey,
|
||||
c."userId" AS userId,
|
||||
c."amount" AS amount,
|
||||
c."amount" * -1 AS negativeAmount,
|
||||
c."type" AS transactionType,
|
||||
c."createdAt" AS transactionTime,
|
||||
c.metadata->>'block_id' AS blockId,
|
||||
c.metadata->>'block' AS blockName,
|
||||
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||
c.metadata->'input'->>'model' AS llm_model,
|
||||
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||
FROM platform."CreditTransaction" c
|
||||
LEFT JOIN platform."AgentNodeExecution" ne
|
||||
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||
@@ -1,45 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding
|
||||
-- Looker source alias: ds68 | Charts: 3
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user onboarding record. Contains the user's
|
||||
-- stated usage reason, selected integrations, completed
|
||||
-- onboarding steps and optional first agent selection.
|
||||
-- Full history (no date filter) since onboarding happens
|
||||
-- once per user.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding state per user
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- id TEXT Onboarding record UUID
|
||||
-- createdAt TIMESTAMPTZ When onboarding started
|
||||
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||
-- integrations TEXT[] Array of integration names the user selected
|
||||
-- userId TEXT User UUID
|
||||
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Usage reason breakdown
|
||||
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||
--
|
||||
-- -- Completion rate per step
|
||||
-- SELECT step, COUNT(*) AS users_completed
|
||||
-- FROM analytics.user_onboarding
|
||||
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||
-- =============================================================
|
||||
|
||||
SELECT
|
||||
id,
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"usageReason",
|
||||
integrations,
|
||||
"userId",
|
||||
"completedSteps",
|
||||
"selectedStoreListingVersionId"
|
||||
FROM platform."UserOnboarding"
|
||||
@@ -1,100 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_funnel
|
||||
-- Looker source alias: ds74 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated onboarding funnel showing how many users
|
||||
-- completed each step and the drop-off percentage from the
|
||||
-- previous step. One row per onboarding step (all 22 steps
|
||||
-- always present, even with 0 completions — prevents sparse
|
||||
-- gaps from making LAG compare the wrong predecessors).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||
-- users_completed BIGINT Distinct users who completed this step
|
||||
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||
--
|
||||
-- STEP ORDER
|
||||
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full funnel
|
||||
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||
--
|
||||
-- -- Biggest drop-off point
|
||||
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||
-- =============================================================
|
||||
|
||||
WITH all_steps AS (
|
||||
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||
-- are always present, keeping LAG comparisons correct.
|
||||
SELECT step_name, step_order
|
||||
FROM (VALUES
|
||||
('WELCOME', 1),
|
||||
('USAGE_REASON', 2),
|
||||
('INTEGRATIONS', 3),
|
||||
('AGENT_CHOICE', 4),
|
||||
('AGENT_NEW_RUN', 5),
|
||||
('AGENT_INPUT', 6),
|
||||
('CONGRATS', 7),
|
||||
('GET_RESULTS', 8),
|
||||
('MARKETPLACE_VISIT', 9),
|
||||
('MARKETPLACE_ADD_AGENT', 10),
|
||||
('MARKETPLACE_RUN_AGENT', 11),
|
||||
('BUILDER_OPEN', 12),
|
||||
('BUILDER_SAVE_AGENT', 13),
|
||||
('BUILDER_RUN_AGENT', 14),
|
||||
('VISIT_COPILOT', 15),
|
||||
('RE_RUN_AGENT', 16),
|
||||
('SCHEDULE_AGENT', 17),
|
||||
('RUN_AGENTS', 18),
|
||||
('RUN_3_DAYS', 19),
|
||||
('TRIGGER_WEBHOOK', 20),
|
||||
('RUN_14_DAYS', 21),
|
||||
('RUN_AGENTS_100', 22)
|
||||
) AS t(step_name, step_order)
|
||||
),
|
||||
raw AS (
|
||||
SELECT
|
||||
u."userId",
|
||||
step_txt::text AS step
|
||||
FROM platform."UserOnboarding" u
|
||||
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
),
|
||||
step_counts AS (
|
||||
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||
FROM raw GROUP BY step
|
||||
),
|
||||
funnel AS (
|
||||
SELECT
|
||||
a.step_name AS step,
|
||||
a.step_order,
|
||||
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||
ROUND(
|
||||
100.0 * COALESCE(sc.users_completed, 0)
|
||||
/ NULLIF(
|
||||
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||
0
|
||||
),
|
||||
2
|
||||
) AS pct_from_prev
|
||||
FROM all_steps a
|
||||
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||
)
|
||||
SELECT * FROM funnel ORDER BY step_order
|
||||
@@ -1,41 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.user_onboarding_integration
|
||||
-- Looker source alias: ds75 | Charts: 1
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- Pre-aggregated count of users who selected each integration
|
||||
-- during onboarding. One row per integration type, sorted
|
||||
-- by popularity.
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- platform.UserOnboarding — integrations array column
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||
--
|
||||
-- WINDOW
|
||||
-- Users who started onboarding in the last 90 days
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Full integration popularity ranking
|
||||
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||
--
|
||||
-- -- Top 5 integrations
|
||||
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||
-- =============================================================
|
||||
|
||||
WITH exploded AS (
|
||||
SELECT
|
||||
u."userId" AS user_id,
|
||||
UNNEST(u."integrations") AS integration
|
||||
FROM platform."UserOnboarding" u
|
||||
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||
)
|
||||
SELECT
|
||||
integration,
|
||||
COUNT(DISTINCT user_id) AS users_with_integration
|
||||
FROM exploded
|
||||
WHERE integration IS NOT NULL AND integration <> ''
|
||||
GROUP BY integration
|
||||
ORDER BY users_with_integration DESC
|
||||
@@ -1,145 +0,0 @@
|
||||
-- =============================================================
|
||||
-- View: analytics.users_activities
|
||||
-- Looker source alias: ds56 | Charts: 5
|
||||
-- =============================================================
|
||||
-- DESCRIPTION
|
||||
-- One row per user with lifetime activity summary.
|
||||
-- Joins login sessions with agent graphs, executions and
|
||||
-- node-level runs to give a full picture of how engaged
|
||||
-- each user is. Includes a convenience flag for 7-day
|
||||
-- activation (did the user return at least 7 days after
|
||||
-- their first login?).
|
||||
--
|
||||
-- SOURCE TABLES
|
||||
-- auth.sessions — Login/session records
|
||||
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||
-- platform.AgentGraphExecution — Agent run history
|
||||
-- platform.AgentNodeExecution — Individual block execution history
|
||||
--
|
||||
-- PERFORMANCE NOTE
|
||||
-- Each CTE aggregates its own table independently by userId.
|
||||
-- This avoids the fan-out that occurs when driving every join
|
||||
-- from user_logins across the two largest tables
|
||||
-- (AgentGraphExecution and AgentNodeExecution).
|
||||
--
|
||||
-- OUTPUT COLUMNS
|
||||
-- user_id TEXT Supabase user UUID
|
||||
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||
-- node_execution_count BIGINT Total node executions across all runs
|
||||
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||
--
|
||||
-- EXAMPLE QUERIES
|
||||
-- -- Users who ran at least one agent and returned after 7 days
|
||||
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||
--
|
||||
-- -- Top 10 most active users by agent runs
|
||||
-- SELECT user_id, agent_runs, node_execution_count
|
||||
-- FROM analytics.users_activities
|
||||
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||
--
|
||||
-- -- 7-day activation rate
|
||||
-- SELECT
|
||||
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||
-- AS activation_rate
|
||||
-- FROM analytics.users_activities;
|
||||
-- =============================================================
|
||||
|
||||
WITH user_logins AS (
|
||||
SELECT
|
||||
user_id::text AS user_id,
|
||||
MIN(created_at) AS first_login_time,
|
||||
MAX(created_at) AS last_login_time,
|
||||
GREATEST(
|
||||
MAX(refreshed_at)::timestamptz,
|
||||
MAX(created_at)::timestamptz
|
||||
) AS last_visit_time
|
||||
FROM auth.sessions
|
||||
GROUP BY user_id
|
||||
),
|
||||
user_agents AS (
|
||||
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MAX("updatedAt") AS last_agent_save_time,
|
||||
COUNT(DISTINCT "id") AS agent_count
|
||||
FROM platform."AgentGraph"
|
||||
WHERE "isActive"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_graph_runs AS (
|
||||
-- Aggregate AgentGraphExecution directly by userId
|
||||
SELECT
|
||||
"userId"::text AS user_id,
|
||||
MIN("createdAt") AS first_agent_run_time,
|
||||
MAX("createdAt") AS last_agent_run_time,
|
||||
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||
COUNT("id") AS agent_runs
|
||||
FROM platform."AgentGraphExecution"
|
||||
GROUP BY "userId"
|
||||
),
|
||||
user_node_runs AS (
|
||||
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||
-- single join to AgentGraphExecution instead of fanning out from
|
||||
-- user_logins through both large tables.
|
||||
SELECT
|
||||
g."userId"::text AS user_id,
|
||||
COUNT(*) AS node_execution_count,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||
FROM platform."AgentNodeExecution" n
|
||||
JOIN platform."AgentGraphExecution" g
|
||||
ON g."id" = n."agentGraphExecutionId"
|
||||
GROUP BY g."userId"
|
||||
)
|
||||
SELECT
|
||||
ul.user_id,
|
||||
ul.first_login_time,
|
||||
ul.last_login_time,
|
||||
ul.last_visit_time,
|
||||
ua.last_agent_save_time,
|
||||
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||
gr.first_agent_run_time,
|
||||
gr.last_agent_run_time,
|
||||
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||
CASE
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||
ELSE NULL
|
||||
END AS is_active_after_7d,
|
||||
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||
FROM user_logins ul
|
||||
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||
1857
autogpt_platform/autogpt_libs/poetry.lock
generated
1857
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^46.0"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.128.7"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.15.0"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.28.0"
|
||||
uvicorn = "^0.40.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.408"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== SIGNUP / INVITE GATE ===== ##
|
||||
# Set to true to require an invite before users can sign up
|
||||
ENABLE_INVITE_GATE=false
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
@@ -108,12 +104,6 @@ TWITTER_CLIENT_SECRET=
|
||||
# Make a new workspace for your OAuth APP -- trust me
|
||||
# https://linear.app/settings/api/applications/new
|
||||
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
||||
LINEAR_API_KEY=
|
||||
# Linear project and team IDs for the feature request tracker.
|
||||
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
|
||||
# and in team settings. Used by the chat copilot to file and search feature requests.
|
||||
LINEAR_FEATURE_REQUEST_PROJECT_ID=
|
||||
LINEAR_FEATURE_REQUEST_TEAM_ID=
|
||||
LINEAR_CLIENT_ID=
|
||||
LINEAR_CLIENT_SECRET=
|
||||
|
||||
@@ -194,8 +184,5 @@ ZEROBOUNCE_API_KEY=
|
||||
POSTHOG_API_KEY=
|
||||
POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# Tally Form Integration (pre-populate business understanding on signup)
|
||||
TALLY_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
# ============================ DEPENDENCY BUILDER ============================ #
|
||||
|
||||
FROM debian:13-slim AS builder
|
||||
|
||||
# Set environment variables
|
||||
@@ -53,106 +51,60 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
|
||||
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
|
||||
FROM debian:13-slim AS migrate
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy Node.js from builder (needed for Prisma CLI)
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install prisma-client-py directly (much smaller than copying full venv)
|
||||
RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
|
||||
FROM debian:13-slim AS server
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV POETRY_HOME=/opt/poetry \
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=true \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
imagemagick \
|
||||
jq \
|
||||
ripgrep \
|
||||
tree \
|
||||
bubblewrap \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
# Copy Node.js installation for Prisma
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
# The .venv includes the generated Prisma client
|
||||
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
FROM server_dependencies AS migrate
|
||||
|
||||
# Copy dependency files + autogpt_libs (path dependency)
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
|
||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
||||
|
||||
# Copy backend code + docs (for Copilot docs search)
|
||||
COPY autogpt_platform/backend ./
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
COPY docs /app/docs
|
||||
# Install the project package to create entry point scripts in .venv/bin/
|
||||
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
|
||||
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
poetry install --no-ansi --only-root
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
CMD ["rest"]
|
||||
CMD ["poetry", "run", "rest"]
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
"""Common test fixtures for server tests.
|
||||
|
||||
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
|
||||
setup_test_user, and setup_admin_user are defined in the parent conftest.py
|
||||
(backend/conftest.py) and are available here automatically.
|
||||
"""
|
||||
"""Common test fixtures for server tests."""
|
||||
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
|
||||
@@ -88,23 +88,20 @@ async def require_auth(
|
||||
)
|
||||
|
||||
|
||||
def require_permission(*permissions: APIKeyPermission):
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking required permissions.
|
||||
All listed permissions must be present.
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permissions(
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
missing = [p for p in permissions if p not in auth.scopes]
|
||||
if missing:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission(s): "
|
||||
f"{', '.join(p.value for p in missing)}",
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permissions
|
||||
return check_permission
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,17 +9,15 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
import backend.data.block
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
@@ -69,7 +67,7 @@ async def get_user_info(
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
)
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.blocks.get_blocks().values()]
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
return [b.to_dict() for b in blocks if not b.disabled]
|
||||
|
||||
|
||||
@@ -85,7 +83,7 @@ async def execute_graph_block(
|
||||
require_permission(APIKeyPermission.EXECUTE_BLOCK)
|
||||
),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.blocks.get_block(block_id)
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
if obj.disabled:
|
||||
@@ -97,43 +95,6 @@ async def execute_graph_block(
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs",
|
||||
tags=["graphs"],
|
||||
status_code=201,
|
||||
dependencies=[
|
||||
Security(
|
||||
require_permission(
|
||||
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
async def create_graph(
|
||||
graph: graph_db.Graph,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
|
||||
),
|
||||
) -> graph_db.GraphModel:
|
||||
"""
|
||||
Create a new agent graph.
|
||||
|
||||
The graph will be validated and assigned a new ID.
|
||||
It is automatically added to the user's library.
|
||||
"""
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
graph_model = graph_db.make_graph_model(graph, auth.user_id)
|
||||
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
|
||||
graph_model.validate_graph(for_run=False)
|
||||
|
||||
await graph_db.create_graph(graph_model, user_id=auth.user_id)
|
||||
await library_db.create_library_agent(graph_model, auth.user_id)
|
||||
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
|
||||
|
||||
return activated_graph
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
@@ -231,13 +192,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -279,7 +240,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -307,13 +268,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -349,7 +310,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import find_agent_tool, run_agent_tool
|
||||
from backend.copilot.tools.models import ToolResponseBase
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import prisma.enums
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||
|
||||
|
||||
class UserHistoryResponse(BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
|
||||
class AddUserCreditsResponse(BaseModel):
|
||||
new_balance: int
|
||||
transaction_key: str
|
||||
|
||||
|
||||
class CreateInvitedUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class InvitedUserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
status: prisma.enums.InvitedUserStatus
|
||||
auth_user_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tally_understanding: Optional[dict[str, Any]] = None
|
||||
tally_status: prisma.enums.TallyComputationStatus
|
||||
tally_computed_at: Optional[datetime] = None
|
||||
tally_error: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||
return cls.model_validate(record.model_dump())
|
||||
|
||||
|
||||
class InvitedUsersResponse(BaseModel):
|
||||
invited_users: list[InvitedUserResponse]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class BulkInvitedUserRowResponse(BaseModel):
|
||||
row_number: int
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||
message: str
|
||||
invited_user: Optional[InvitedUserResponse] = None
|
||||
|
||||
|
||||
class BulkInvitedUsersResponse(BaseModel):
|
||||
created_count: int
|
||||
skipped_count: int
|
||||
error_count: int
|
||||
results: list[BulkInvitedUserRowResponse]
|
||||
|
||||
@classmethod
|
||||
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||
return cls(
|
||||
created_count=result.created_count,
|
||||
skipped_count=result.skipped_count,
|
||||
error_count=result.error_count,
|
||||
results=[
|
||||
BulkInvitedUserRowResponse(
|
||||
row_number=row.row_number,
|
||||
email=row.email,
|
||||
name=row.name,
|
||||
status=row.status,
|
||||
message=row.message,
|
||||
invited_user=(
|
||||
InvitedUserResponse.from_record(row.invited_user)
|
||||
if row.invited_user is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for row in result.results
|
||||
],
|
||||
)
|
||||
|
||||
@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
):
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
Paginated listings with their versions
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
):
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -73,24 +84,31 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmissionAdminView with updated review information
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||
|
||||
from backend.data.invited_user import (
|
||||
bulk_create_invited_users_from_file,
|
||||
create_invited_user,
|
||||
list_invited_users,
|
||||
retry_invited_user_tally,
|
||||
revoke_invited_user,
|
||||
)
|
||||
from backend.data.tally import mask_email
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BulkInvitedUsersResponse,
|
||||
CreateInvitedUserRequest,
|
||||
InvitedUserResponse,
|
||||
InvitedUsersResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
tags=["users", "admin"],
|
||||
dependencies=[Security(requires_admin_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/invited-users",
|
||||
response_model=InvitedUsersResponse,
|
||||
summary="List Invited Users",
|
||||
)
|
||||
async def get_invited_users(
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
) -> InvitedUsersResponse:
|
||||
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||
return InvitedUsersResponse(
|
||||
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=max(1, math.ceil(total / page_size)),
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Create Invited User",
|
||||
)
|
||||
async def create_invited_user_route(
|
||||
request: CreateInvitedUserRequest,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s creating invited user for %s",
|
||||
admin_user_id,
|
||||
mask_email(request.email),
|
||||
)
|
||||
invited_user = await create_invited_user(request.email, request.name)
|
||||
logger.info(
|
||||
"Admin user %s created invited user %s",
|
||||
admin_user_id,
|
||||
invited_user.id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/bulk",
|
||||
response_model=BulkInvitedUsersResponse,
|
||||
summary="Bulk Create Invited Users",
|
||||
operation_id="postV2BulkCreateInvitedUsers",
|
||||
)
|
||||
async def bulk_create_invited_users_route(
|
||||
file: UploadFile = File(...),
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> BulkInvitedUsersResponse:
|
||||
logger.info(
|
||||
"Admin user %s bulk invited users from %s",
|
||||
admin_user_id,
|
||||
file.filename or "<unnamed>",
|
||||
)
|
||||
content = await file.read()
|
||||
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||
return BulkInvitedUsersResponse.from_result(result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/revoke",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Revoke Invited User",
|
||||
)
|
||||
async def revoke_invited_user_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||
)
|
||||
invited_user = await revoke_invited_user(invited_user_id)
|
||||
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invited-users/{invited_user_id}/retry-tally",
|
||||
response_model=InvitedUserResponse,
|
||||
summary="Retry Invited User Tally",
|
||||
)
|
||||
async def retry_invited_user_tally_route(
|
||||
invited_user_id: str,
|
||||
admin_user_id: str = Security(get_user_id),
|
||||
) -> InvitedUserResponse:
|
||||
logger.info(
|
||||
"Admin user %s retrying Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||
logger.info(
|
||||
"Admin user %s retried Tally seed for invited user %s",
|
||||
admin_user_id,
|
||||
invited_user_id,
|
||||
)
|
||||
return InvitedUserResponse.from_record(invited_user)
|
||||
@@ -1,168 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.invited_user import (
|
||||
BulkInvitedUserRowResult,
|
||||
BulkInvitedUsersResult,
|
||||
InvitedUserRecord,
|
||||
)
|
||||
|
||||
from .user_admin_routes import router as user_admin_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(user_admin_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_admin_auth(mock_jwt_admin):
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _sample_invited_user() -> InvitedUserRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return InvitedUserRecord(
|
||||
id="invite-1",
|
||||
email="invited@example.com",
|
||||
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||
auth_user_id=None,
|
||||
name="Invited User",
|
||||
tally_understanding=None,
|
||||
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||
tally_computed_at=None,
|
||||
tally_error=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||
return BulkInvitedUsersResult(
|
||||
created_count=1,
|
||||
skipped_count=1,
|
||||
error_count=0,
|
||||
results=[
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=1,
|
||||
email="invited@example.com",
|
||||
name=None,
|
||||
status="CREATED",
|
||||
message="Invite created",
|
||||
invited_user=_sample_invited_user(),
|
||||
),
|
||||
BulkInvitedUserRowResult(
|
||||
row_number=2,
|
||||
email="duplicate@example.com",
|
||||
name=None,
|
||||
status="SKIPPED",
|
||||
message="An invited user with this email already exists",
|
||||
invited_user=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_get_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||
)
|
||||
|
||||
response = client.get("/admin/invited-users")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["invited_users"]) == 1
|
||||
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||
assert data["invited_users"][0]["status"] == "INVITED"
|
||||
assert data["pagination"]["total_items"] == 1
|
||||
assert data["pagination"]["current_page"] == 1
|
||||
assert data["pagination"]["page_size"] == 50
|
||||
|
||||
|
||||
def test_create_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||
AsyncMock(return_value=_sample_invited_user()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users",
|
||||
json={"email": "invited@example.com", "name": "Invited User"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == "invited@example.com"
|
||||
assert data["name"] == "Invited User"
|
||||
|
||||
|
||||
def test_bulk_create_invited_users(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/admin/invited-users/bulk",
|
||||
files={
|
||||
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert data["skipped_count"] == 1
|
||||
assert data["results"][0]["status"] == "CREATED"
|
||||
assert data["results"][1]["status"] == "SKIPPED"
|
||||
|
||||
|
||||
def test_revoke_invited_user(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
revoked = _sample_invited_user().model_copy(
|
||||
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||
AsyncMock(return_value=revoked),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "REVOKED"
|
||||
|
||||
|
||||
def test_retry_invited_user_tally(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
retried = _sample_invited_user().model_copy(
|
||||
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||
AsyncMock(return_value=retried),
|
||||
)
|
||||
|
||||
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["tally_status"] == "RUNNING"
|
||||
@@ -1,26 +1,20 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
from typing import Sequence
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
BlockCategory,
|
||||
BlockInfo,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -28,7 +22,7 @@ from backend.util.models import Pagination
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
BlockTypeFilter,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
FilterType,
|
||||
Provider,
|
||||
@@ -43,16 +37,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
||||
|
||||
# Boost blocks over marketplace agents in search results
|
||||
BLOCK_SCORE_BOOST = 50.0
|
||||
|
||||
# Block IDs to exclude from search results
|
||||
EXCLUDED_BLOCK_IDS = frozenset(
|
||||
{
|
||||
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
|
||||
}
|
||||
)
|
||||
|
||||
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
||||
|
||||
|
||||
@@ -75,8 +59,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
# Skip disabled and excluded blocks
|
||||
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
@@ -104,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
def get_blocks(
|
||||
*,
|
||||
category: str | None = None,
|
||||
type: BlockTypeFilter | None = None,
|
||||
type: BlockType | None = None,
|
||||
provider: ProviderName | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
@@ -127,9 +111,6 @@ def get_blocks(
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
@@ -269,25 +250,14 @@ async def _build_cached_search_results(
|
||||
"my_agents": 0,
|
||||
}
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
elif include_blocks or include_integrations:
|
||||
# No query - list all blocks using in-memory approach
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
normalized_query=normalized_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
|
||||
if include_library_agents:
|
||||
library_response = await library_db.list_library_agents(
|
||||
@@ -332,14 +302,10 @@ async def _build_cached_search_results(
|
||||
|
||||
def _collect_block_results(
|
||||
*,
|
||||
normalized_query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Collect all blocks for listing (no search query).
|
||||
|
||||
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
@@ -352,10 +318,6 @@ def _collect_block_results(
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
block_info = block.get_info()
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
is_integration = len(credentials) > 0
|
||||
@@ -365,6 +327,10 @@ def _collect_block_results(
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
score = _score_block(block, block_info, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
@@ -375,122 +341,8 @@ def _collect_block_results(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=BLOCK_SCORE_BOOST,
|
||||
sort_key=block_info.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
score=score,
|
||||
sort_key=_get_item_name(block_info),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -615,8 +467,6 @@ async def _get_static_counts():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
all_blocks += 1
|
||||
|
||||
@@ -643,25 +493,47 @@ async def _get_static_counts():
|
||||
}
|
||||
|
||||
|
||||
def _contains_type(annotation: Any, target: type) -> bool:
|
||||
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||
if annotation is target:
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||
|
||||
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if _contains_type(field.annotation, LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _score_block(
|
||||
block: AnyBlockSchema,
|
||||
block_info: BlockInfo,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = block_info.name.lower()
|
||||
description = block_info.description.lower()
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
category_text = " ".join(
|
||||
category.get("category", "").lower() for category in block_info.categories
|
||||
)
|
||||
score += _score_additional_field(category_text, normalized_query, 12, 6)
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info
|
||||
for provider in info.provider
|
||||
]
|
||||
provider_text = " ".join(provider_names)
|
||||
score += _score_additional_field(provider_text, normalized_query, 15, 6)
|
||||
|
||||
if _matches_llm_model(block.input_schema, normalized_query):
|
||||
score += 20
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_library_agent(
|
||||
agent: library_model.LibraryAgent,
|
||||
normalized_query: str,
|
||||
@@ -768,32 +640,45 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600, shared_cache=True)
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
"""Return the most-executed blocks from the last 14 days.
|
||||
suggested_blocks = []
|
||||
# Sum the number of executions for each block type
|
||||
# Prisma cannot group by nested relations, so we do a raw query
|
||||
# Calculate the cutoff timestamp
|
||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
|
||||
and returns the top `count` blocks sorted by execution count, excluding
|
||||
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
|
||||
"""
|
||||
results = await mv_suggested_blocks.prisma().find_many()
|
||||
results = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
agent_node."agentBlockId" AS block_id,
|
||||
COUNT(execution.id) AS execution_count
|
||||
FROM {schema_prefix}"AgentNodeExecution" execution
|
||||
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
||||
WHERE execution."endedTime" >= $1::timestamp
|
||||
GROUP BY agent_node."agentBlockId"
|
||||
ORDER BY execution_count DESC;
|
||||
""",
|
||||
timestamp_threshold,
|
||||
)
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input, Output, Agent, and excluded blocks
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
execution_counts = {row.block_id: row.execution_count for row in results}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled or block.block_type in (
|
||||
BlockType.INPUT,
|
||||
BlockType.OUTPUT,
|
||||
BlockType.AGENT,
|
||||
backend.data.block.BlockType.INPUT,
|
||||
backend.data.block.BlockType.OUTPUT,
|
||||
backend.data.block.BlockType.AGENT,
|
||||
):
|
||||
continue
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
execution_count = execution_counts.get(block.id, 0)
|
||||
# Find the execution count for this block
|
||||
execution_count = next(
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.blocks._base import BlockInfo
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
@@ -15,7 +15,7 @@ FilterType = Literal[
|
||||
"my_agents",
|
||||
]
|
||||
|
||||
BlockTypeFilter = Literal["all", "input", "action", "output"]
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
|
||||
class SearchEntry(BaseModel):
|
||||
@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[SearchEntry]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockInfo]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Annotated, Sequence, cast, get_args
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -51,6 +49,11 @@ async def get_suggestions(
|
||||
Get all suggestions for the Blocks Menu.
|
||||
"""
|
||||
return builder_model.SuggestionsResponse(
|
||||
otto_suggestions=[
|
||||
"What blocks do I need to get started?",
|
||||
"Help me create a list",
|
||||
"Help me feed my data to Google Maps",
|
||||
],
|
||||
recent_searches=await builder_db.get_recent_searches(user_id),
|
||||
providers=[
|
||||
ProviderName.TWITTER,
|
||||
@@ -85,7 +88,7 @@ async def get_block_categories(
|
||||
)
|
||||
async def get_blocks(
|
||||
category: Annotated[str | None, fastapi.Query()] = None,
|
||||
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
|
||||
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
|
||||
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
page_size: Annotated[int, fastapi.Query()] = 50,
|
||||
@@ -148,7 +151,7 @@ async def get_providers(
|
||||
async def search(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
search_query: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
|
||||
search_id: Annotated[str | None, fastapi.Query()] = None,
|
||||
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
@@ -157,20 +160,9 @@ async def search(
|
||||
"""
|
||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||
"""
|
||||
# Parse and validate filter parameter
|
||||
filters: list[builder_model.FilterType]
|
||||
if filter:
|
||||
filter_values = [f.strip() for f in filter.split(",")]
|
||||
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
|
||||
if invalid_filters:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
|
||||
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
|
||||
)
|
||||
filters = cast(list[builder_model.FilterType], filter_values)
|
||||
else:
|
||||
filters = [
|
||||
# If no filters are provided, then we will return all types
|
||||
if not filter:
|
||||
filter = [
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
@@ -182,7 +174,7 @@ async def search(
|
||||
cached_results = await builder_db.get_sorted_search_results(
|
||||
user_id=user_id,
|
||||
search_query=search_query,
|
||||
filters=filters,
|
||||
filters=filter,
|
||||
by_creator=by_creator,
|
||||
)
|
||||
|
||||
@@ -204,7 +196,7 @@ async def search(
|
||||
user_id,
|
||||
builder_model.SearchEntry(
|
||||
search_query=search_query,
|
||||
filter=filters,
|
||||
filter=filter,
|
||||
by_creator=by_creator,
|
||||
search_id=search_id,
|
||||
),
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
This consumer initializes its own Prisma client in start() to ensure
|
||||
database operations work correctly within this async context.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._prisma: Prisma | None = None
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def _ensure_prisma(self) -> Prisma:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._prisma:
|
||||
await self._prisma.disconnect()
|
||||
self._prisma = None
|
||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message using our own Prisma client."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_success(task, message.result, prisma)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -0,0 +1,344 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
prisma_client: Prisma | None,
|
||||
) -> None:
|
||||
"""Update tool message in database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The caller should
|
||||
handle this to avoid marking the task as completed with inconsistent state.
|
||||
"""
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
updated_count = await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={"content": content},
|
||||
)
|
||||
# Check if any rows were updated - 0 means message not found
|
||||
if updated_count == 0:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=content,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task will be
|
||||
marked as failed instead of completed to avoid inconsistent state.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database with
|
||||
the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
146
autogpt_platform/backend/backend/api/features/chat/config.py
Normal file
146
autogpt_platform/backend/backend/api/features/chat/config.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default="https://openrouter.ai/api/v1",
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if v is None:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
if not v:
|
||||
# Fall back to OPEN_ROUTER_API_KEY
|
||||
v = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if v is None:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
v = os.getenv("OPENROUTER_BASE_URL")
|
||||
if not v:
|
||||
v = os.getenv("OPENAI_BASE_URL")
|
||||
if not v:
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
"onboarding": "prompts/onboarding_system.md",
|
||||
}
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
291
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
291
autogpt_platform/backend/backend/api/features/chat/db.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = SafeJson(credentials)
|
||||
if successful_agent_runs is not None:
|
||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
||||
if successful_agent_schedules is not None:
|
||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
||||
if total_prompt_tokens is not None:
|
||||
data["totalPromptTokens"] = total_prompt_tokens
|
||||
if total_completion_tokens is not None:
|
||||
data["totalCompletionTokens"] = total_completion_tokens
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
sequence: int,
|
||||
content: str | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Run message create and session timestamp update in parallel for lower latency
|
||||
_, message = await asyncio.gather(
|
||||
PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
"""Get the total number of chat sessions for a user."""
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion of other
|
||||
users' sessions.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Build typed where clause with optional user_id validation
|
||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No session deleted for {session_id} "
|
||||
f"(user_id validation: {user_id is not None})"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
new_content: str,
|
||||
) -> bool:
|
||||
"""Update the content of a tool message in chat history.
|
||||
|
||||
Used by background tasks to update pending operation messages with final results.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
tool_call_id: The tool call ID to find the message.
|
||||
new_content: The new content to set.
|
||||
|
||||
Returns:
|
||||
True if a message was updated, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await PrismaChatMessage.prisma().update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": new_content,
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No message found to update for session {session_id}, "
|
||||
f"tool_call_id {tool_call_id}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update tool message for session {session_id}, "
|
||||
f"tool_call_id {tool_call_id}: {e}"
|
||||
)
|
||||
return False
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Self, cast
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# ===================== Chat data models ===================== #
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -55,33 +85,18 @@ class ChatMessage(BaseModel):
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
|
||||
return ChatMessage(
|
||||
role=prisma_message.role,
|
||||
content=prisma_message.content,
|
||||
name=prisma_message.name,
|
||||
tool_call_id=prisma_message.toolCallId,
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
)
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
@@ -89,9 +104,40 @@ class ChatSessionInfo(BaseModel):
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
"""Convert Prisma ChatSession to Pydantic ChatSession."""
|
||||
@staticmethod
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
@@ -113,10 +159,11 @@ class ChatSessionInfo(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
return ChatSession(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
@@ -125,56 +172,6 @@ class ChatSessionInfo(BaseModel):
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
|
||||
class ChatSession(ChatSessionInfo):
|
||||
messages: list[ChatMessage]
|
||||
|
||||
@classmethod
|
||||
def new(cls, user_id: str) -> Self:
|
||||
return cls(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
"""Convert Prisma ChatSession to Pydantic ChatSession."""
|
||||
if prisma_session.Messages is None:
|
||||
raise ValueError(
|
||||
f"Prisma session {prisma_session.id} is missing Messages relation"
|
||||
)
|
||||
|
||||
return cls(
|
||||
**ChatSessionInfo.from_db(prisma_session).model_dump(),
|
||||
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
|
||||
)
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
Searches backwards for the most recent assistant message (stopping at
|
||||
any user message boundary). If found, appends the tool_call to it.
|
||||
Otherwise creates a new assistant message with the tool_call.
|
||||
"""
|
||||
for msg in reversed(self.messages):
|
||||
if msg.role == "user":
|
||||
break
|
||||
if msg.role == "assistant":
|
||||
if not msg.tool_calls:
|
||||
msg.tool_calls = []
|
||||
msg.tool_calls.append(tool_call)
|
||||
return
|
||||
|
||||
self.messages.append(
|
||||
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
@@ -261,72 +258,43 @@ class ChatSession(ChatSessionInfo):
|
||||
name=message.name or "",
|
||||
)
|
||||
)
|
||||
return self._merge_consecutive_assistant_messages(messages)
|
||||
|
||||
@staticmethod
|
||||
def _merge_consecutive_assistant_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Merge consecutive assistant messages into single messages.
|
||||
|
||||
Long-running tool flows can create split assistant messages: one with
|
||||
text content and another with tool_calls. Anthropic's API requires
|
||||
tool_result blocks to reference a tool_use in the immediately preceding
|
||||
assistant message, so these splits cause 400 errors via OpenRouter.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
result: list[ChatCompletionMessageParam] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev = result[-1]
|
||||
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
||||
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
||||
|
||||
curr_content = curr.get("content") or ""
|
||||
if curr_content:
|
||||
prev_content = prev.get("content") or ""
|
||||
prev["content"] = (
|
||||
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
||||
)
|
||||
|
||||
curr_tool_calls = curr.get("tool_calls")
|
||||
if curr_tool_calls:
|
||||
prev_tool_calls = prev.get("tool_calls")
|
||||
prev["tool_calls"] = (
|
||||
list(prev_tool_calls) + list(curr_tool_calls)
|
||||
if prev_tool_calls
|
||||
else list(curr_tool_calls)
|
||||
)
|
||||
return result
|
||||
return messages
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
# ================ Chat cache + DB operations ================ #
|
||||
|
||||
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
|
||||
# connected directly.
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis (without persisting to the database)."""
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def invalidate_session_cache(session_id: str) -> None:
|
||||
"""Invalidate a chat session from Redis cache.
|
||||
|
||||
@@ -342,6 +310,80 @@ async def invalidate_session_cache(session_id: str) -> None:
|
||||
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
@@ -373,7 +415,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -389,7 +431,7 @@ async def get_chat_session(
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
@@ -397,44 +439,6 @@ async def get_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
session = await chat_db().get_chat_session(session_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Loaded session {session_id} from DB: "
|
||||
f"has_messages={bool(session.messages)}, "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
@@ -454,35 +458,25 @@ async def upsert_chat_session(
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
existing_message_count,
|
||||
skip_existence_check=existing_message_count > 0,
|
||||
)
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
@@ -503,107 +497,6 @@ async def upsert_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession,
|
||||
existing_message_count: int,
|
||||
*,
|
||||
skip_existence_check: bool = False,
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database.
|
||||
|
||||
Args:
|
||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||
and assume the session row already exists. Saves one DB round trip
|
||||
for incremental saves during streaming.
|
||||
"""
|
||||
db = chat_db()
|
||||
|
||||
if not skip_existence_check:
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist message to session {session_id}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
@@ -616,7 +509,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db().create_chat_session(
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -628,7 +521,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
@@ -639,16 +532,20 @@ async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ChatSessionInfo], int]:
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
db = chat_db()
|
||||
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
@@ -666,7 +563,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db().delete_chat_session(session_id, user_id)
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
@@ -683,89 +580,38 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
try:
|
||||
from .tools.agent_browser import close_browser_session
|
||||
|
||||
await close_browser_session(session_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Browser cleanup for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
)
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
119
autogpt_platform/backend/backend/api/features/chat/model_test.py
Normal file
119
autogpt_platform/backend/backend/api/features/chat/model_test.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import pytest
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
|
||||
messages = [
|
||||
ChatMessage(content="Hello, how are you?", role="user"),
|
||||
ChatMessage(
|
||||
content="I'm fine, thank you!",
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "New York"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
content="I'm using the tool to get the weather",
|
||||
role="tool",
|
||||
tool_call_id="t123",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_serialization_deserialization():
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s.messages = messages
|
||||
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
|
||||
serialized = s.model_dump_json()
|
||||
s2 = ChatSession.model_validate_json(serialized)
|
||||
assert s2.model_dump() == s.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 == s
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
|
||||
# Load from DB (cache was cleared)
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
user_id=s.user_id,
|
||||
)
|
||||
|
||||
assert s2 is not None, "Session not found after loading from DB"
|
||||
assert len(s2.messages) == len(
|
||||
s.messages
|
||||
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
|
||||
|
||||
# Verify all roles are present
|
||||
roles = [m.role for m in s2.messages]
|
||||
assert "user" in roles, f"User message missing. Roles found: {roles}"
|
||||
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
|
||||
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
|
||||
|
||||
# Verify message content
|
||||
for orig, loaded in zip(s.messages, s2.messages):
|
||||
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
|
||||
assert (
|
||||
orig.content == loaded.content
|
||||
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
|
||||
if orig.tool_calls:
|
||||
assert (
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
@@ -5,18 +5,11 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
@@ -25,10 +18,6 @@ class ResponseType(str, Enum):
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Step lifecycle (one LLM API call within a message)
|
||||
START_STEP = "start-step"
|
||||
FINISH_STEP = "finish-step"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
@@ -52,8 +41,7 @@ class StreamBaseResponse(BaseModel):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
json_str = self.model_dump_json(exclude_none=True)
|
||||
return f"data: {json_str}\n\n"
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
@@ -64,19 +52,11 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
sessionId: str | None = Field(
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Session ID for SSE reconnection.",
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
}
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
"""End of message/stream."""
|
||||
@@ -84,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
class StreamStartStep(StreamBaseResponse):
|
||||
"""Start of a step (one LLM API call within a message).
|
||||
|
||||
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||
enabling visual separation between multiple LLM calls in a single message.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.START_STEP
|
||||
|
||||
|
||||
class StreamFinishStep(StreamBaseResponse):
|
||||
"""End of a step (one LLM API call within a message).
|
||||
|
||||
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||
so the next LLM call in a tool-call continuation starts with clean state.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH_STEP
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
@@ -151,16 +111,13 @@ class StreamToolInputAvailable(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Keep these for internal backend use
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
@@ -168,47 +125,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Truncate oversized outputs after construction."""
|
||||
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
"output": self.output,
|
||||
}
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of uncached prompt tokens")
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(
|
||||
..., description="Total number of tokens (raw, not weighted)"
|
||||
)
|
||||
cacheReadTokens: int = Field(
|
||||
default=0, description="Prompt tokens served from cache (10% cost)"
|
||||
)
|
||||
cacheCreationTokens: int = Field(
|
||||
default=0, description="Prompt tokens written to cache (25% cost)"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True)}\n\n"
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
@@ -221,18 +148,6 @@ class StreamError(StreamBaseResponse):
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
|
||||
|
||||
The AI SDK uses z.strictObject({type, errorText}) which rejects
|
||||
any extra fields like `code` or `details`.
|
||||
"""
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"errorText": self.errorText,
|
||||
}
|
||||
return f"data: {json_dumps(data)}\n\n"
|
||||
|
||||
|
||||
class StreamHeartbeat(StreamBaseResponse):
|
||||
"""Heartbeat to keep SSE connection alive during long-running operations.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,379 +0,0 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, and suggested prompts."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_get_business_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
return_value=None,
|
||||
):
|
||||
"""Mock get_business_understanding."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_business_understanding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_suggested_prompts_returns_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding and prompts gets them back."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||
|
||||
|
||||
def test_suggested_prompts_no_understanding(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with no understanding gets empty list."""
|
||||
_mock_get_business_understanding(mocker, return_value=None)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
|
||||
|
||||
def test_suggested_prompts_empty_prompts(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""User with understanding but no prompts gets empty list."""
|
||||
mock_understanding = MagicMock()
|
||||
mock_understanding.suggested_prompts = []
|
||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||
|
||||
response = client.get("/suggested-prompts")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"prompts": []}
|
||||
1889
autogpt_platform/backend/backend/api/features/chat/service.py
Normal file
1889
autogpt_platform/backend/backend/api/features/chat/service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,82 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
@@ -0,0 +1,704 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It uses Redis for all state management (no in-memory state), making
|
||||
pods stateless and horizontally scalable.
|
||||
|
||||
Architecture:
|
||||
- Redis Stream: Persists all messages for replay and real-time delivery
|
||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||
|
||||
Subscribers:
|
||||
1. Replay missed messages from Redis Stream (XREAD)
|
||||
2. Listen for live updates via blocking XREAD
|
||||
3. No in-memory state required on the subscribing pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
# Track listener tasks per subscriber queue for cleanup
|
||||
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
|
||||
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
|
||||
# Timeout for putting chunks into subscriber queues (seconds)
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
COMPLETE_TASK_SCRIPT = """
|
||||
local current = redis.call("HGET", KEYS[1], "status")
|
||||
if current == "running" then
|
||||
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{config.task_meta_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{config.task_stream_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{config.task_op_prefix}{operation_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous)
|
||||
tool_call_id: Tool call ID from the LLM
|
||||
tool_name: Name of the tool being executed
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance (metadata only)
|
||||
"""
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store metadata in Redis
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id or "",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"operation_id": operation_id,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to Redis Stream.
|
||||
|
||||
All delivery is via Redis Streams - no in-memory state.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID
|
||||
"""
|
||||
chunk_json = chunk.model_dump_json()
|
||||
message_id = "0-0"
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Write to Redis Stream for persistence and real-time delivery
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Set TTL on stream to match task metadata TTL
|
||||
await redis.expire(stream_key, config.stream_ttl)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish chunk for task {task_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
|
||||
task_id: str,
|
||||
user_id: str | None,
|
||||
last_message_id: str = "0-0",
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||
|
||||
Returns:
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
logger.debug(f"Task {task_id} not found in Redis")
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id:
|
||||
if user_id != task_user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} denied access to task {task_id} "
|
||||
f"owned by {task_user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
listener_task = asyncio.create_task(
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||
)
|
||||
# Track listener task for cleanup on unsubscribe
|
||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def _stream_listener(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
|
||||
This approach avoids the duplicate message issue that can occur with pub/sub
|
||||
when messages are published during the gap between replay and subscription.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
subscriber_queue: Queue to deliver messages to
|
||||
last_replayed_id: Last message ID from replay (continue from here)
|
||||
"""
|
||||
queue_id = id(subscriber_queue)
|
||||
# Track the last successfully delivered message ID for recovery hints
|
||||
last_delivered_id = last_replayed_id
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
current_id = last_replayed_id
|
||||
|
||||
while True:
|
||||
# Block for up to 30 seconds waiting for new messages
|
||||
# This allows periodic checking if task is still running
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=30000, count=100
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if task is still running
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||
if status and status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering finish event for task {task_id}"
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
|
||||
if "data" not in msg_data:
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, "
|
||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
# Queue is completely stuck, nothing more we can do
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for task {task_id}, "
|
||||
"queue completely blocked"
|
||||
)
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing stream message: {e}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||
raise # Re-raise to propagate cancellation
|
||||
except Exception as e:
|
||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||
# On error, send finish to unblock subscriber
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||
logger.warning(
|
||||
f"Could not deliver finish event for task {task_id} after error"
|
||||
)
|
||||
finally:
|
||||
# Clean up listener task mapping on exit
|
||||
_listener_tasks.pop(queue_id, None)
|
||||
|
||||
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
) -> bool:
|
||||
"""Mark a task as completed and publish finish event.
|
||||
|
||||
This is idempotent - calling multiple times with the same task_id is safe.
|
||||
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
||||
Status is updated first (source of truth), then finish event is published (best-effort).
|
||||
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
|
||||
Returns:
|
||||
True if task was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
# This prevents race conditions when multiple callers try to complete simultaneously
|
||||
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||
try:
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish finish event for task {task_id}: {e}. "
|
||||
"Listeners will detect completion via status polling."
|
||||
)
|
||||
|
||||
# Clean up local task reference if exists
|
||||
_local_tasks.pop(task_id, None)
|
||||
return True
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
"""Find a task by its operation ID.
|
||||
|
||||
Used by webhook callbacks to locate the task to update.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID to search for
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if not task_id:
|
||||
return None
|
||||
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
return await get_task(task_id_str)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID from Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
async def get_task_with_expiry_info(
|
||||
task_id: str,
|
||||
) -> tuple[ActiveTask | None, str | None]:
|
||||
"""Get a task by its ID with expiration detection.
|
||||
|
||||
Returns (task, error_code) where error_code is:
|
||||
- None if task found
|
||||
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
|
||||
- "TASK_NOT_FOUND" if neither exists
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask or None, error_code or None)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
# Check if stream still has data (metadata expired but stream hasn't)
|
||||
stream_len = await redis.xlen(stream_key)
|
||||
if stream_len > 0:
|
||||
return None, "TASK_EXPIRED"
|
||||
return None, "TASK_NOT_FOUND"
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
async def get_active_task_for_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[ActiveTask | None, str]:
|
||||
"""Get the active (running) task for a session, if any.
|
||||
|
||||
Scans Redis for tasks matching the session_id with status="running".
|
||||
|
||||
Args:
|
||||
session_id: Session ID to look up
|
||||
user_id: User ID for ownership validation (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
||||
"""
|
||||
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Scan Redis for task metadata keys
|
||||
cursor = 0
|
||||
tasks_checked = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await redis.scan(
|
||||
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
tasks_checked += 1
|
||||
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
task_session_id = meta.get("session_id", "")
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
task_id = meta.get("task_id", "")
|
||||
|
||||
if task_session_id == session_id and task_status == "running":
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
try:
|
||||
messages = await redis.xrevrange(stream_key, count=1)
|
||||
if messages:
|
||||
msg_id = messages[0][0]
|
||||
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get last message ID: {e}")
|
||||
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=task_session_id,
|
||||
user_id=task_user_id,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status="running",
|
||||
),
|
||||
last_id,
|
||||
)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
return None, "0-0"
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||
|
||||
Args:
|
||||
chunk_data: Parsed JSON data from Redis
|
||||
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
# Map response types to their corresponding classes
|
||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||
ResponseType.START.value: StreamStart,
|
||||
ResponseType.FINISH.value: StreamFinish,
|
||||
ResponseType.TEXT_START.value: StreamTextStart,
|
||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
||||
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
||||
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
||||
ResponseType.ERROR.value: StreamError,
|
||||
ResponseType.USAGE.value: StreamUsage,
|
||||
ResponseType.HEARTBEAT.value: StreamHeartbeat,
|
||||
}
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||
|
||||
if chunk_class is None:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return chunk_class(**chunk_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||
"""Track the asyncio.Task for a task (local reference only).
|
||||
|
||||
This is just for cleanup purposes - the task state is in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
asyncio_task: The asyncio Task to track
|
||||
"""
|
||||
_local_tasks[task_id] = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Clean up when a subscriber disconnects.
|
||||
|
||||
Cancels the XREAD-based listener task associated with this subscriber queue
|
||||
to prevent resource leaks.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
subscriber_queue: The subscriber's queue used to look up the listener task
|
||||
"""
|
||||
queue_id = id(subscriber_queue)
|
||||
listener_entry = _listener_tasks.pop(queue_id, None)
|
||||
|
||||
if listener_entry is None:
|
||||
logger.debug(
|
||||
f"No listener task found for task {task_id} queue {queue_id} "
|
||||
"(may have already completed)"
|
||||
)
|
||||
return
|
||||
|
||||
stored_task_id, listener_task = listener_entry
|
||||
|
||||
if stored_task_id != task_id:
|
||||
logger.warning(
|
||||
f"Task ID mismatch in unsubscribe: expected {task_id}, "
|
||||
f"found {stored_task_id}"
|
||||
)
|
||||
|
||||
if listener_task.done():
|
||||
logger.debug(f"Listener task for task {task_id} already completed")
|
||||
return
|
||||
|
||||
# Cancel the listener task
|
||||
listener_task.cancel()
|
||||
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||
except asyncio.CancelledError:
|
||||
# Expected - the task was successfully cancelled
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout waiting for listener task cancellation for task {task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from task {task_id}")
|
||||
@@ -1,43 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .fix_agent import FixAgentGraphTool
|
||||
from .get_agent_building_guide import GetAgentBuildingGuideTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .get_mcp_guide import GetMCPGuideTool
|
||||
from .manage_folders import (
|
||||
CreateFolderTool,
|
||||
DeleteFolderTool,
|
||||
ListFoldersTool,
|
||||
MoveAgentsToFolderTool,
|
||||
MoveFolderTool,
|
||||
UpdateFolderTool,
|
||||
)
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
@@ -46,8 +27,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,37 +40,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
# Folder management tools
|
||||
"create_folder": CreateFolderTool(),
|
||||
"list_folders": ListFoldersTool(),
|
||||
"update_folder": UpdateFolderTool(),
|
||||
"move_folder": MoveFolderTool(),
|
||||
"delete_folder": DeleteFolderTool(),
|
||||
"move_agents_to_folder": MoveAgentsToFolderTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
"get_agent_building_guide": GetAgentBuildingGuideTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
"browser_navigate": BrowserNavigateTool(),
|
||||
"browser_act": BrowserActTool(),
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
"create_feature_request": CreateFeatureRequestTool(),
|
||||
# Agent generation tools (local validation/fixing)
|
||||
"validate_agent_graph": ValidateAgentGraphTool(),
|
||||
"fix_agent_graph": FixAgentGraphTool(),
|
||||
# Workspace tools for CoPilot file operations
|
||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||
@@ -102,17 +56,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
|
||||
def get_available_tools() -> list[ChatCompletionToolParam]:
|
||||
"""Return OpenAI tool schemas for tools available in the current environment.
|
||||
|
||||
Called per-request so that env-var or binary availability is evaluated
|
||||
fresh each time (e.g. browser_* tools are excluded when agent-browser
|
||||
CLI is not installed).
|
||||
"""
|
||||
return [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
|
||||
]
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
def get_tool(tool_name: str) -> BaseTool | None:
|
||||
@@ -1,46 +1,22 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data import db as db_module
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_db_connected() -> None:
|
||||
"""Ensure the Prisma connection is alive on the current event loop.
|
||||
|
||||
On Python 3.11, the httpx transport inside Prisma can reference a stale
|
||||
(closed) event loop when session-scoped async fixtures are evaluated long
|
||||
after the initial ``server`` fixture connected Prisma. A cheap health-check
|
||||
followed by a reconnect fixes this without affecting other fixtures.
|
||||
"""
|
||||
try:
|
||||
await prisma.query_raw("SELECT 1")
|
||||
except Exception:
|
||||
_logger.info("Prisma connection stale – reconnecting")
|
||||
try:
|
||||
await db_module.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
await db_module.connect()
|
||||
|
||||
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
@@ -55,19 +31,15 @@ def make_session(user_id: str):
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_test_data(server):
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_test_data():
|
||||
"""
|
||||
Set up test data for run_agent tests:
|
||||
1. Create a test user
|
||||
2. Create a test graph (agent input -> agent output)
|
||||
3. Create a store listing and store listing version
|
||||
4. Approve the store listing version
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
"sub": f"test-user-{uuid.uuid4()}",
|
||||
@@ -151,8 +123,8 @@ async def setup_test_data(server):
|
||||
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Test Agent",
|
||||
description="A simple test agent",
|
||||
@@ -161,10 +133,10 @@ async def setup_test_data(server):
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
|
||||
assert store_submission.listing_version_id is not None
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
# 4. Approve the store listing version
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval",
|
||||
@@ -178,19 +150,15 @@ async def setup_test_data(server):
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_llm_test_data(server):
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_llm_test_data():
|
||||
"""
|
||||
Set up test data for LLM agent tests:
|
||||
1. Create a test user
|
||||
2. Create test OpenAI credentials for the user
|
||||
3. Create a test graph with input -> LLM block -> output
|
||||
4. Create and approve a store listing
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
key = getenv("OPENAI_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("OPENAI_API_KEY is not set")
|
||||
@@ -321,8 +289,8 @@ async def setup_llm_test_data(server):
|
||||
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="LLM Test Agent",
|
||||
description="An agent with LLM capabilities",
|
||||
@@ -330,9 +298,9 @@ async def setup_llm_test_data(server):
|
||||
categories=["testing", "ai"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.listing_version_id is not None
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for LLM agent",
|
||||
@@ -347,18 +315,14 @@ async def setup_llm_test_data(server):
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def setup_firecrawl_test_data(server):
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_firecrawl_test_data():
|
||||
"""
|
||||
Set up test data for Firecrawl agent tests (missing credentials scenario):
|
||||
1. Create a test user (WITHOUT Firecrawl credentials)
|
||||
2. Create a test graph with input -> Firecrawl block -> output
|
||||
3. Create and approve a store listing
|
||||
|
||||
Depends on ``server`` to ensure Prisma is connected.
|
||||
"""
|
||||
await _ensure_db_connected()
|
||||
|
||||
# 1. Create a test user
|
||||
user_data = {
|
||||
"sub": f"test-user-{uuid.uuid4()}",
|
||||
@@ -476,8 +440,8 @@ async def setup_firecrawl_test_data(server):
|
||||
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
|
||||
store_submission = await store_db.create_store_submission(
|
||||
user_id=user.id,
|
||||
graph_id=created_graph.id,
|
||||
graph_version=created_graph.version,
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
slug=unique_slug,
|
||||
name="Firecrawl Test Agent",
|
||||
description="An agent with Firecrawl integration (no credentials)",
|
||||
@@ -485,9 +449,9 @@ async def setup_firecrawl_test_data(server):
|
||||
categories=["testing", "scraping"],
|
||||
image_urls=["https://example.com/image.jpg"],
|
||||
)
|
||||
assert store_submission.listing_version_id is not None
|
||||
assert store_submission.store_listing_version_id is not None
|
||||
await store_db.review_store_submission(
|
||||
store_listing_version_id=store_submission.listing_version_id,
|
||||
store_listing_version_id=store_submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Approved for testing",
|
||||
internal_comments="Test approval for Firecrawl agent",
|
||||
@@ -3,9 +3,11 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await understanding_db().upsert_business_understanding(
|
||||
user_id, input_data
|
||||
)
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
@@ -1,20 +1,24 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
customize_template,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
extract_uuids_from_text,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_by_ids,
|
||||
get_library_agents_for_generation,
|
||||
graph_to_json,
|
||||
json_to_graph,
|
||||
@@ -22,28 +26,33 @@ from .core import (
|
||||
search_marketplace_agents_for_generation,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .validation import AgentFixer, AgentValidator
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
"AgentFixer",
|
||||
"AgentValidator",
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"customize_template",
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
"extract_uuids_from_text",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"get_agent_as_json",
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_by_ids",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
"search_marketplace_agents_for_generation",
|
||||
@@ -3,14 +3,20 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .helpers import UUID_RE_STR
|
||||
from .service import (
|
||||
customize_template_external,
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
is_external_service_configured,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,7 +78,38 @@ class DecompositionResult(TypedDict, total=False):
|
||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(UUID_RE_STR, re.IGNORECASE)
|
||||
def _to_dict_list(
|
||||
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
return None
|
||||
return [dict(a) for a in agents]
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_service_configured() -> None:
|
||||
"""Check if the external Agent Generator service is configured.
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
raise AgentGeneratorNotConfiguredError(
|
||||
"Agent Generator service is not configured. "
|
||||
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||
)
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_uuids_from_text(text: str) -> list[str]:
|
||||
@@ -108,9 +145,8 @@ async def get_library_agent_by_id(
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
db = library_db()
|
||||
try:
|
||||
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -127,7 +163,7 @@ async def get_library_agent_by_id(
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await db.get_library_agent(agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -154,36 +190,6 @@ async def get_library_agent_by_id(
|
||||
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||
|
||||
|
||||
async def get_library_agents_by_ids(
|
||||
user_id: str,
|
||||
agent_ids: list[str],
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Fetch multiple library agents by their IDs.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary for found agents (silently skips not found)
|
||||
"""
|
||||
agents: list[LibraryAgentSummary] = []
|
||||
for agent_id in agent_ids:
|
||||
try:
|
||||
agent = await get_library_agent_by_id(user_id, agent_id)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.debug(f"Fetched library agent by ID: {agent['name']}")
|
||||
else:
|
||||
logger.warning(f"Library agent not found for ID: {agent_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
|
||||
return agents
|
||||
|
||||
|
||||
async def get_library_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
@@ -208,17 +214,10 @@ async def get_library_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
search_term = search_query.strip() if search_query else None
|
||||
if search_term and len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await library_db().list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_term,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
include_executions=True,
|
||||
@@ -272,16 +271,9 @@ async def search_marketplace_agents_for_generation(
|
||||
Returns:
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
search_term = search_query.strip()
|
||||
if len(search_term) > 100:
|
||||
raise ValueError(
|
||||
f"Search query is too long ({len(search_term)} chars, max 100). "
|
||||
f"Please use a shorter, more specific search term."
|
||||
)
|
||||
|
||||
try:
|
||||
response = await store_db().get_store_agents(
|
||||
search_query=search_term,
|
||||
response = await store_db.get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
)
|
||||
@@ -294,7 +286,7 @@ async def search_marketplace_agents_for_generation(
|
||||
return []
|
||||
|
||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||
graphs = await graph_db().get_store_listed_graphs(graph_ids)
|
||||
graphs = await get_store_listed_graphs(*graph_ids)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in agents_with_graphs:
|
||||
@@ -432,7 +424,7 @@ def extract_search_terms_from_steps(
|
||||
async def enrich_library_agents_from_steps(
|
||||
user_id: str,
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
|
||||
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||
exclude_graph_id: str | None = None,
|
||||
include_marketplace: bool = True,
|
||||
max_additional_results: int = 10,
|
||||
@@ -456,7 +448,7 @@ async def enrich_library_agents_from_steps(
|
||||
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
if not search_terms:
|
||||
return list(existing_agents)
|
||||
return existing_agents
|
||||
|
||||
existing_ids: set[str] = set()
|
||||
existing_names: set[str] = set()
|
||||
@@ -516,6 +508,79 @@ async def enrich_library_agents_from_steps(
|
||||
return all_agents
|
||||
|
||||
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
DecompositionResult with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
result = await decompose_goal_external(
|
||||
description, context, _to_dict_list(library_agents)
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams
|
||||
completion notification)
|
||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||
and SSE delivery)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
@@ -595,10 +660,7 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any],
|
||||
user_id: str,
|
||||
is_update: bool = False,
|
||||
folder_id: str | None = None,
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
@@ -606,16 +668,14 @@ async def save_agent_to_library(
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
folder_id: Optional folder ID to place the agent in
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
graph = json_to_graph(agent_json)
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id, folder_id=folder_id)
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
@@ -675,14 +735,12 @@ async def get_agent_as_json(
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
db = graph_db()
|
||||
|
||||
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db().get_library_agent(agent_id, user_id)
|
||||
graph = await db.get_graph(
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
@@ -692,3 +750,76 @@ async def get_agent_as_json(
|
||||
return None
|
||||
|
||||
return graph_to_json(graph)
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Generating the patch
|
||||
- Applying the patch
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
|
||||
async def customize_template(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Customize a template/marketplace agent using natural language.
|
||||
|
||||
This is used when users want to modify a template or marketplace agent
|
||||
to fit their specific needs before adding it to their library.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Understanding the modification request
|
||||
- Applying changes to the template
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for customize_template")
|
||||
return await customize_template_external(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
@@ -0,0 +1,498 @@
|
||||
"""External Agent Generator service client.
|
||||
|
||||
This module provides a client for communicating with the external Agent Generator
|
||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
}
|
||||
if details:
|
||||
response["details"] = details
|
||||
return response
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
elif status == 503:
|
||||
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||
elif status == 504 or status == 408:
|
||||
return "timeout", f"Agent Generator timed out: {e}"
|
||||
else:
|
||||
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
elif "connect" in error_str:
|
||||
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||
else:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def _get_settings() -> Settings:
|
||||
"""Get or create settings singleton."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def is_external_service_configured() -> bool:
|
||||
"""Check if external Agent Generator service is configured."""
|
||||
settings = _get_settings()
|
||||
return bool(settings.config.agentgenerator_host)
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL for the external service."""
|
||||
settings = _get_settings()
|
||||
host = settings.config.agentgenerator_host
|
||||
port = settings.config.agentgenerator_port
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client for the external service."""
|
||||
global _client
|
||||
if _client is None:
|
||||
settings = _get_settings()
|
||||
_client = httpx.AsyncClient(
|
||||
base_url=_get_base_url(),
|
||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
if response_type == "instructions":
|
||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||
elif response_type == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
elif response_type == "unachievable_goal":
|
||||
return {
|
||||
"type": "unachievable_goal",
|
||||
"reason": data.get("reason"),
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "vague_goal":
|
||||
return {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
Returns:
|
||||
List of block info dicts or None on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/api/blocks")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error("External service returned error getting blocks")
|
||||
return None
|
||||
|
||||
return data.get("blocks", [])
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error getting blocks from external service: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def health_check() -> bool:
|
||||
"""Check if the external service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not is_external_service_configured():
|
||||
return False
|
||||
|
||||
client = _get_client()
|
||||
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||
except Exception as e:
|
||||
logger.warning(f"External agent generator health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def close_client() -> None:
|
||||
"""Close the HTTP client."""
|
||||
global _client
|
||||
if _client is not None:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
@@ -5,15 +5,15 @@ import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
|
||||
from .models import (
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
@@ -34,7 +34,6 @@ class AgentOutputInput(BaseModel):
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
wait_if_running: int = Field(default=0, ge=0, le=300)
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
@@ -118,11 +117,6 @@ class AgentOutputTool(BaseTool):
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
|
||||
Wait for completion (optional):
|
||||
- wait_if_running: Max seconds to wait if execution is still running (0-300).
|
||||
If the execution is running/queued, waits up to this many seconds for completion.
|
||||
Returns current status on timeout. If already finished, returns immediately.
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -152,13 +146,6 @@ class AgentOutputTool(BaseTool):
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max seconds to wait if execution is still running (0-300). "
|
||||
"If running, waits for completion. Returns current state on timeout."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
@@ -178,12 +165,10 @@ class AgentOutputTool(BaseTool):
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(library_agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
@@ -197,7 +182,7 @@ class AgentOutputTool(BaseTool):
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
@@ -209,7 +194,7 @@ class AgentOutputTool(BaseTool):
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await lib_db.list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
@@ -238,20 +223,14 @@ class AgentOutputTool(BaseTool):
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
include_running: bool = False,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
|
||||
Args:
|
||||
include_running: If True, also look for running/queued executions (for waiting)
|
||||
"""
|
||||
exec_db = execution_db()
|
||||
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await exec_db.get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -260,25 +239,11 @@ class AgentOutputTool(BaseTool):
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Determine which statuses to query
|
||||
statuses = [ExecutionStatus.COMPLETED]
|
||||
if include_running:
|
||||
statuses.extend(
|
||||
[
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.FAILED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
]
|
||||
)
|
||||
|
||||
# Get executions with time filters
|
||||
executions = await exec_db.get_graph_executions(
|
||||
# Get completed executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=statuses,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
@@ -289,7 +254,7 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -297,7 +262,7 @@ class AgentOutputTool(BaseTool):
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -345,33 +310,10 @@ class AgentOutputTool(BaseTool):
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
# Build appropriate message based on execution status
|
||||
if execution.status == ExecutionStatus.COMPLETED:
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
elif execution.status == ExecutionStatus.FAILED:
|
||||
message = f"Execution for agent '{agent.name}' failed"
|
||||
elif execution.status == ExecutionStatus.TERMINATED:
|
||||
message = f"Execution for agent '{agent.name}' was terminated"
|
||||
elif execution.status == ExecutionStatus.REVIEW:
|
||||
message = (
|
||||
f"Execution for agent '{agent.name}' is awaiting human review. "
|
||||
"The user needs to approve it before it can continue."
|
||||
)
|
||||
elif execution.status in (
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
message = (
|
||||
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
|
||||
"Results may be incomplete. Use wait_if_running to wait for completion."
|
||||
)
|
||||
else:
|
||||
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f" Showing latest of {len(available_executions)} matching executions."
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
@@ -438,7 +380,7 @@ class AgentOutputTool(BaseTool):
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db().get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -450,7 +392,7 @@ class AgentOutputTool(BaseTool):
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db().get_library_agent_by_graph_id(
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
@@ -486,17 +428,13 @@ class AgentOutputTool(BaseTool):
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Check if we should wait for running executions
|
||||
wait_timeout = input_data.wait_if_running
|
||||
|
||||
# Fetch execution(s) - include running if we're going to wait
|
||||
# Fetch execution(s)
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
include_running=wait_timeout > 0,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
@@ -505,17 +443,4 @@ class AgentOutputTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If we have an execution that's still running and we should wait
|
||||
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
|
||||
logger.info(
|
||||
f"Execution {execution.id} is {execution.status}, "
|
||||
f"waiting up to {wait_timeout}s for completion"
|
||||
)
|
||||
execution = await wait_for_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=execution.id,
|
||||
timeout_seconds=wait_timeout,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
@@ -1,15 +1,11 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
@@ -29,24 +25,92 @@ _UUID_PATTERN = re.compile(
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Keywords that should be treated as "list all" rather than a literal search
|
||||
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
Tries multiple lookup strategies:
|
||||
1. First by graph_id (AgentGraph primary key)
|
||||
2. Then by library agent ID (LibraryAgent primary key)
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
source: SearchSource,
|
||||
session_id: str | None = None,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
|
||||
For library searches, keywords like "all", "*", "everything", or an empty
|
||||
query will list all agents without filtering.
|
||||
|
||||
Args:
|
||||
query: Search query string. Special keywords list all library agents.
|
||||
query: Search query string
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
@@ -54,11 +118,7 @@ async def search_agents(
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
# Normalize list-all keywords to empty string for library searches
|
||||
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
if source == "marketplace" and not query:
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
@@ -73,7 +133,7 @@ async def search_agents(
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
@@ -98,18 +158,28 @@ async def search_agents(
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
search_term = query or None
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
results = await library_db().list_library_agents(
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=search_term,
|
||||
page_size=50 if not query else 10,
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
@@ -122,62 +192,42 @@ async def search_agents(
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if source == "marketplace":
|
||||
suggestions = [
|
||||
suggestions = (
|
||||
[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Let the user know they can "
|
||||
"try different keywords or browse the marketplace. Also let them "
|
||||
"know you can create a custom agent for them based on their needs."
|
||||
)
|
||||
elif not query:
|
||||
# User asked to list all but library is empty
|
||||
suggestions = [
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
]
|
||||
no_results_msg = (
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
)
|
||||
else:
|
||||
suggestions = [
|
||||
if source == "marketplace"
|
||||
else [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
no_results_msg = (
|
||||
f"No agents matching '{query}' found in your library. Let the "
|
||||
"user know you can create a custom agent for them based on "
|
||||
"their needs."
|
||||
)
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
)
|
||||
|
||||
if source == "marketplace":
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
elif not query:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
|
||||
else:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
title += (
|
||||
f"for '{query}'"
|
||||
if source == "marketplace"
|
||||
else f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
@@ -187,70 +237,3 @@ async def search_agents(
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
"""Convert a library agent model to an AgentInfo."""
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
Tries multiple lookup strategies:
|
||||
1. First by graph_id (AgentGraph primary key)
|
||||
2. Then by library agent ID (LibraryAgent primary key)
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by graph_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
129
autogpt_platform/backend/backend/api/features/chat/tools/base.py
Normal file
129
autogpt_platform/backend/backend/api/features/chat/tools/base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all chat tools."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Tool name for OpenAI function calling."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Tool description for OpenAI."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""Tool parameters schema for OpenAI."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
"""Whether this tool is long-running and should execute in background.
|
||||
|
||||
Long-running tools (like agent generation) are executed via background
|
||||
tasks to survive SSE disconnections. The result is persisted to chat
|
||||
history and visible when the user refreshes.
|
||||
"""
|
||||
return False
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous like "anon_123")
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
Returns:
|
||||
Pydantic response object
|
||||
|
||||
"""
|
||||
if self.requires_auth and not user_id:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
success=False,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
success=False,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Internal execution logic to be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
user_id: User ID (authenticated or anonymous)
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
Returns:
|
||||
Pydantic response object
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,335 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||
error="decomposition_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="analyze the goal",
|
||||
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"decomposition_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if user_id and library_agents is not None:
|
||||
try:
|
||||
library_agents = await enrich_library_agents_from_steps(
|
||||
user_id=user_id,
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=library_agents,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent generation is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||
error="generation_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message=(
|
||||
"I wasn't able to create a valid agent for this request. "
|
||||
"The generated workflow had some structural issues. "
|
||||
"Please try simplifying your goal or breaking it into smaller steps."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"generation_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,337 @@
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
customize_template,
|
||||
get_user_message_for_error,
|
||||
graph_to_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "customize_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace or template agent using natural language. "
|
||||
"Takes an existing agent from the marketplace and modifies it based on "
|
||||
"the user's requirements before adding to their library."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The marketplace agent ID in format 'creator/slug' "
|
||||
"(e.g., 'autogpt/newsletter-writer'). "
|
||||
"Get this from find_agent results."
|
||||
),
|
||||
},
|
||||
"modifications": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of how to customize the agent. "
|
||||
"Be specific about what changes you want to make."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the customized agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Parse the agent ID to get creator/slug
|
||||
2. Fetch the template agent from the marketplace
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
error="invalid_agent_id_format",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the full agent graph
|
||||
try:
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Call customize_template
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent customization is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_service_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent. "
|
||||
"The agent generation service may be unavailable or timed out. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
message="Failed to customize the agent due to an unexpected response.",
|
||||
error="unexpected_response_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
agent_description = customized_agent.get("description", "")
|
||||
nodes = customized_agent.get("nodes")
|
||||
links = customized_agent.get("links")
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
f"The customized agent has {node_count} blocks. "
|
||||
f"Review it and call customize_agent with save=true to save it."
|
||||
),
|
||||
agent_json=customized_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Customized agent '{created_graph.name}' "
|
||||
f"(based on '{agent_details.agent_name}') "
|
||||
f"has been saved to your library!"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving customized agent: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to save the customized agent. Please try again.",
|
||||
error="save_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,284 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates updates to the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate updated agent (external service handles fixing and validation)
|
||||
3. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent editing is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
error="update_generation_failed",
|
||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"update_generation_failed:{error_type}",
|
||||
details={
|
||||
"agent_id": agent_id,
|
||||
"changes": changes[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -0,0 +1,193 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockInputFieldInfo,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
"The response includes each block's id, required_inputs, and input_schema."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
"Check spelling of technical terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Enrich results with full block information
|
||||
blocks: list[BlockInfoSummary] = []
|
||||
for result in results:
|
||||
block_id = result["content_id"]
|
||||
block = get_block(block_id)
|
||||
|
||||
# Skip disabled blocks
|
||||
if block and not block.disabled:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||
"To execute a block, use run_block with the block's 'id' field "
|
||||
"and provide 'input_data' matching the block's input_schema."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
@@ -19,13 +19,9 @@ class FindLibraryAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for or list agents in the user's library. Use this to find "
|
||||
"agents the user has already added to their library, including agents "
|
||||
"they created or added from the marketplace. "
|
||||
"When creating agents with sub-agent composition, use this to get "
|
||||
"the agent's graph_id, graph_version, input_schema, and output_schema "
|
||||
"needed for AgentExecutorBlock nodes. "
|
||||
"Omit the query to list all agents."
|
||||
"Search for agents in the user's library. Use this to find agents "
|
||||
"the user has already added to their library, including agents they "
|
||||
"created or added from the marketplace."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -35,13 +31,10 @@ class FindLibraryAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Omit to list all agents in the library."
|
||||
),
|
||||
"description": "Search query to find agents by name or description.",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -52,7 +45,7 @@ class FindLibraryAgentTool(BaseTool):
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
return await search_agents(
|
||||
query=(kwargs.get("query") or "").strip(),
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
@@ -4,10 +4,13 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
AGENT_PREVIEW = "agent_preview"
|
||||
AGENT_SAVED = "agent_saved"
|
||||
CLARIFICATION_NEEDED = "clarification_needed"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
# Workspace response types
|
||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
# Long-running operation types
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
|
||||
"""Base model for all tool responses."""
|
||||
|
||||
type: ResponseType
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
# Agent discovery models
|
||||
class AgentInfo(BaseModel):
|
||||
"""Information about an agent."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
source: str = Field(description="marketplace or library")
|
||||
in_library: bool = False
|
||||
creator: str | None = None
|
||||
category: str | None = None
|
||||
rating: float | None = None
|
||||
runs: int | None = None
|
||||
is_featured: bool | None = None
|
||||
status: str | None = None
|
||||
can_access_graph: bool | None = None
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
)
|
||||
|
||||
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agents_found"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
"""Response when no agents found."""
|
||||
|
||||
type: ResponseType = ResponseType.NO_RESULTS
|
||||
suggestions: list[str] = []
|
||||
name: str = "no_results"
|
||||
|
||||
|
||||
# Agent details models
|
||||
class InputField(BaseModel):
|
||||
"""Input field specification."""
|
||||
|
||||
name: str
|
||||
type: str = "string"
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
options: list[Any] | None = None
|
||||
format: str | None = None
|
||||
|
||||
|
||||
class ExecutionOptions(BaseModel):
|
||||
"""Available execution options for an agent."""
|
||||
|
||||
manual: bool = True
|
||||
scheduled: bool = True
|
||||
webhook: bool = False
|
||||
|
||||
|
||||
class AgentDetails(BaseModel):
|
||||
"""Detailed agent information."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
in_library: bool = False
|
||||
inputs: dict[str, Any] = {}
|
||||
credentials: list[CredentialsMetaInput] = []
|
||||
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
|
||||
trigger_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentDetailsResponse(ToolResponseBase):
|
||||
"""Response for get_details action."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_DETAILS
|
||||
agent: AgentDetails
|
||||
user_authenticated: bool = False
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Setup info models
|
||||
class UserReadiness(BaseModel):
|
||||
"""User readiness status."""
|
||||
|
||||
has_all_credentials: bool = False
|
||||
missing_credentials: dict[str, Any] = {}
|
||||
ready_to_run: bool = False
|
||||
|
||||
|
||||
class SetupInfo(BaseModel):
|
||||
"""Complete setup information."""
|
||||
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
requirements: dict[str, list[Any]] = Field(
|
||||
default_factory=lambda: {
|
||||
"credentials": [],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
)
|
||||
user_readiness: UserReadiness = Field(default_factory=UserReadiness)
|
||||
|
||||
|
||||
class SetupRequirementsResponse(ToolResponseBase):
|
||||
"""Response for validate action."""
|
||||
|
||||
type: ResponseType = ResponseType.SETUP_REQUIREMENTS
|
||||
setup_info: SetupInfo
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Execution models
|
||||
class ExecutionStartedResponse(ToolResponseBase):
|
||||
"""Response for run/schedule actions."""
|
||||
|
||||
type: ResponseType = ResponseType.EXECUTION_STARTED
|
||||
execution_id: str
|
||||
graph_id: str
|
||||
graph_name: str
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
status: str = "QUEUED"
|
||||
|
||||
|
||||
# Auth/error models
|
||||
class NeedLoginResponse(ToolResponseBase):
|
||||
"""Response when login is needed."""
|
||||
|
||||
type: ResponseType = ResponseType.NEED_LOGIN
|
||||
agent_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ErrorResponse(ToolResponseBase):
|
||||
"""Response for errors."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
error: str | None = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
||||
unrecognized_fields: list[str] = Field(
|
||||
description="List of input field names that were not recognized"
|
||||
)
|
||||
inputs: dict[str, Any] = Field(
|
||||
description="The agent's valid input schema for reference"
|
||||
)
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
|
||||
execution_id: str
|
||||
status: str
|
||||
started_at: datetime | None = None
|
||||
ended_at: datetime | None = None
|
||||
outputs: dict[str, list[Any]]
|
||||
inputs_summary: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AgentOutputResponse(ToolResponseBase):
|
||||
"""Response for agent_output tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_OUTPUT
|
||||
agent_name: str
|
||||
agent_id: str
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
execution: ExecutionOutputInfo | None = None
|
||||
available_executions: list[dict[str, Any]] | None = None
|
||||
total_executions: int = 0
|
||||
|
||||
|
||||
# Business understanding models
|
||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
"""Response for add_understanding tool."""
|
||||
|
||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||
updated_fields: list[str] = Field(default_factory=list)
|
||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Agent generation models
|
||||
class ClarifyingQuestion(BaseModel):
|
||||
"""A question that needs user clarification."""
|
||||
|
||||
question: str
|
||||
keyword: str
|
||||
example: str | None = None
|
||||
|
||||
|
||||
class AgentPreviewResponse(ToolResponseBase):
|
||||
"""Response for previewing a generated agent before saving."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_PREVIEW
|
||||
agent_json: dict[str, Any]
|
||||
agent_name: str
|
||||
description: str
|
||||
node_count: int
|
||||
link_count: int = 0
|
||||
|
||||
|
||||
class AgentSavedResponse(ToolResponseBase):
|
||||
"""Response when an agent is saved to the library."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_SAVED
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
library_agent_id: str
|
||||
library_agent_link: str
|
||||
agent_page_link: str # Link to the agent builder/editor page
|
||||
|
||||
|
||||
class ClarificationNeededResponse(ToolResponseBase):
|
||||
"""Response when the LLM needs more information from the user."""
|
||||
|
||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Documentation search models
|
||||
class DocSearchResult(BaseModel):
|
||||
"""A single documentation search result."""
|
||||
|
||||
title: str
|
||||
path: str
|
||||
section: str
|
||||
snippet: str # Short excerpt for UI display
|
||||
score: float
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
class DocSearchResultsResponse(ToolResponseBase):
|
||||
"""Response for search_docs tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||
results: list[DocSearchResult]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
class DocPageResponse(ToolResponseBase):
|
||||
"""Response for get_doc_page tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_PAGE
|
||||
title: str
|
||||
path: str
|
||||
content: str # Full document content
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
# Block models
|
||||
class BlockInputFieldInfo(BaseModel):
|
||||
"""Information about a block input field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class BlockInfoSummary(BaseModel):
|
||||
"""Summary of a block for search results."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of required input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
"""Response for find_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_LIST
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the required fields from input_schema."
|
||||
)
|
||||
|
||||
|
||||
class BlockOutputResponse(ToolResponseBase):
|
||||
"""Response for run_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||
block_id: str
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Long-running operation models
|
||||
class OperationStartedResponse(ToolResponseBase):
|
||||
"""Response when a long-running operation has been started in the background.
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
"""Response stored in chat history while a long-running operation is executing.
|
||||
|
||||
This is persisted to the database so users see a pending state when they
|
||||
refresh before the operation completes.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class OperationInProgressResponse(ToolResponseBase):
|
||||
"""Response when an operation is already in progress.
|
||||
|
||||
Returned for idempotency when the same tool_call_id is requested again
|
||||
while the background task is still running.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||
consumer will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
@@ -5,13 +5,16 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import (
|
||||
track_agent_run_success,
|
||||
track_agent_scheduled,
|
||||
)
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -21,15 +24,11 @@ from backend.util.timezone_utils import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .execution_utils import get_execution_outputs, wait_for_execution
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
AgentDetails,
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
ExecutionOutputInfo,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
SetupInfo,
|
||||
@@ -70,7 +69,6 @@ class RunAgentInput(BaseModel):
|
||||
schedule_name: str = ""
|
||||
cron: str = ""
|
||||
timezone: str = "UTC"
|
||||
wait_for_result: int = Field(default=0, ge=0, le=300)
|
||||
|
||||
@field_validator(
|
||||
"username_agent_slug",
|
||||
@@ -152,14 +150,6 @@ class RunAgentTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max seconds to wait for execution to complete (0-300). "
|
||||
"If >0, blocks until the execution finishes or times out. "
|
||||
"Returns execution outputs when complete."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
@@ -209,7 +199,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db().get_library_agent(
|
||||
library_agent = await library_db.get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
@@ -218,7 +208,9 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db().get_graph(
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
@@ -269,7 +261,7 @@ class RunAgentTool(BaseTool):
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
),
|
||||
@@ -354,7 +346,6 @@ class RunAgentTool(BaseTool):
|
||||
graph=graph,
|
||||
graph_credentials=graph_credentials,
|
||||
inputs=params.inputs,
|
||||
wait_for_result=params.wait_for_result,
|
||||
)
|
||||
|
||||
except NotFoundError as e:
|
||||
@@ -378,6 +369,22 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract inputs list from schema."""
|
||||
inputs_list = []
|
||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
||||
for field_name, field_schema in input_schema["properties"].items():
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in input_schema.get("required", []),
|
||||
}
|
||||
)
|
||||
return inputs_list
|
||||
|
||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||
"""Get available execution modes for the graph."""
|
||||
trigger_info = graph.trigger_setup_info
|
||||
@@ -391,7 +398,7 @@ class RunAgentTool(BaseTool):
|
||||
suffix: str,
|
||||
) -> str:
|
||||
"""Build a message describing available inputs for an agent."""
|
||||
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||
|
||||
@@ -438,9 +445,8 @@ class RunAgentTool(BaseTool):
|
||||
graph: GraphModel,
|
||||
graph_credentials: dict[str, CredentialsMetaInput],
|
||||
inputs: dict[str, Any],
|
||||
wait_for_result: int = 0,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute an agent immediately, optionally waiting for completion."""
|
||||
"""Execute an agent immediately."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Check rate limits
|
||||
@@ -477,93 +483,6 @@ class RunAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
library_agent_link = f"/library/agents/{library_agent.id}"
|
||||
|
||||
# If wait_for_result is requested, wait for execution to complete
|
||||
if wait_for_result > 0:
|
||||
logger.info(
|
||||
f"Waiting up to {wait_for_result}s for execution {execution.id}"
|
||||
)
|
||||
completed = await wait_for_execution(
|
||||
user_id=user_id,
|
||||
graph_id=library_agent.graph_id,
|
||||
execution_id=execution.id,
|
||||
timeout_seconds=wait_for_result,
|
||||
)
|
||||
|
||||
if completed and completed.status == ExecutionStatus.COMPLETED:
|
||||
outputs = get_execution_outputs(completed)
|
||||
return AgentOutputResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' completed successfully. "
|
||||
f"View at {library_agent_link}."
|
||||
),
|
||||
session_id=session_id,
|
||||
agent_name=library_agent.name,
|
||||
agent_id=library_agent.graph_id,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
execution=ExecutionOutputInfo(
|
||||
execution_id=execution.id,
|
||||
status=completed.status.value,
|
||||
started_at=completed.started_at,
|
||||
ended_at=completed.ended_at,
|
||||
outputs=outputs or {},
|
||||
),
|
||||
)
|
||||
elif completed and completed.status == ExecutionStatus.FAILED:
|
||||
error_detail = completed.stats.error if completed.stats else None
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' execution failed. "
|
||||
f"View details at {library_agent_link}."
|
||||
),
|
||||
session_id=session_id,
|
||||
error=error_detail,
|
||||
)
|
||||
elif completed and completed.status == ExecutionStatus.TERMINATED:
|
||||
error_detail = completed.stats.error if completed.stats else None
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' execution was terminated. "
|
||||
f"View details at {library_agent_link}."
|
||||
),
|
||||
session_id=session_id,
|
||||
error=error_detail,
|
||||
)
|
||||
elif completed and completed.status == ExecutionStatus.REVIEW:
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' is awaiting human review. "
|
||||
f"The user can approve or reject inline. After approval, "
|
||||
f"the execution resumes automatically. Use view_agent_output "
|
||||
f"with execution_id='{execution.id}' to check the result."
|
||||
),
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_name=library_agent.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
status=ExecutionStatus.REVIEW.value,
|
||||
)
|
||||
else:
|
||||
status = completed.status.value if completed else "unknown"
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' is still {status} after "
|
||||
f"{wait_for_result}s. Check results later at "
|
||||
f"{library_agent_link}. "
|
||||
f"Use view_agent_output with wait_if_running to check again."
|
||||
),
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_name=library_agent.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
status=status,
|
||||
)
|
||||
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' execution started successfully. "
|
||||
@@ -618,7 +537,7 @@ class RunAgentTool(BaseTool):
|
||||
library_agent = await get_or_create_library_agent(graph, user_id)
|
||||
|
||||
# Get user timezone
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
|
||||
|
||||
# Create schedule
|
||||
@@ -0,0 +1,373 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"Use the 'id' from find_block results and provide input_data "
|
||||
"matching the block's required_inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The block's 'id' field from find_block results. "
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Use the 'required_inputs' field "
|
||||
"from find_block to see what fields are needed."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to check credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
input_data = input_data or {}
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
# Get discriminator from input, falling back to schema default
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in effective_field_info.provider
|
||||
and cred.type in effective_field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(effective_field_info.provider), "unknown")
|
||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
if block.disabled:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' is disabled",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block, input_data
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
||||
credentials_fields_info, set(matched_credentials.keys())
|
||||
)
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": missing_creds_list,
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
# This means:
|
||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Create unified execution context with all required fields
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_exec_id,
|
||||
graph_version=1, # Versions are 1-indexed
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
# Workspace with session scoping
|
||||
workspace_id=workspace.id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Prepare kwargs for block execution
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": execution_context,
|
||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_exec_id,
|
||||
"node_exec_id": synthetic_node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1, # Versions are 1-indexed
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -5,17 +5,16 @@ from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import search
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await search().unified_hybrid_search(
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
@@ -3,18 +3,17 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,14 +37,13 @@ async def fetch_graph_from_store_slug(
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during lookup.
|
||||
"""
|
||||
sdb = store_db()
|
||||
try:
|
||||
store_agent = await sdb.get_store_agent_details(username, agent_name)
|
||||
store_agent = await store_db.get_store_agent_details(username, agent_name)
|
||||
except NotFoundError:
|
||||
return None, None
|
||||
|
||||
# Get the graph from store listing version
|
||||
graph = await sdb.get_available_graph(
|
||||
graph = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id, hide_nodes=False
|
||||
)
|
||||
return graph, store_agent
|
||||
@@ -119,7 +117,7 @@ def build_missing_credentials_from_graph(
|
||||
preserving all supported credential types for each field.
|
||||
"""
|
||||
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
||||
aggregated_fields = graph.aggregate_credentials_inputs()
|
||||
aggregated_fields = graph.regular_credentials_inputs
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
@@ -210,13 +208,13 @@ async def get_or_create_library_agent(
|
||||
Returns:
|
||||
LibraryAgent instance
|
||||
"""
|
||||
existing = await library_db().get_library_agent_by_graph_id(
|
||||
existing = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph.id, user_id=user_id
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
library_agents = await library_db().create_library_agent(
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
@@ -225,99 +223,6 @@ async def get_or_create_library_agent(
|
||||
return library_agents[0]
|
||||
|
||||
|
||||
async def match_credentials_to_requirements(
|
||||
user_id: str,
|
||||
requirements: dict[str, CredentialsFieldInfo],
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Match user's credentials against a dictionary of credential requirements.
|
||||
|
||||
This is the core matching logic shared by both graph and block credential matching.
|
||||
"""
|
||||
matched: dict[str, CredentialsMetaInput] = {}
|
||||
missing: list[CredentialsMetaInput] = []
|
||||
|
||||
if not requirements:
|
||||
return matched, missing
|
||||
|
||||
available_creds = await get_user_credentials(user_id)
|
||||
|
||||
for field_name, field_info in requirements.items():
|
||||
matching_cred = find_matching_credential(available_creds, field_info)
|
||||
|
||||
if matching_cred:
|
||||
try:
|
||||
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||
f"credential_id={matching_cred.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=f"{field_name} (validation failed: {e})",
|
||||
)
|
||||
)
|
||||
else:
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched, missing
|
||||
|
||||
|
||||
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||
"""Get all available credentials for a user."""
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
return await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
|
||||
def find_matching_credential(
|
||||
available_creds: list[Credentials],
|
||||
field_info: CredentialsFieldInfo,
|
||||
) -> Credentials | None:
|
||||
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||
for cred in available_creds:
|
||||
if cred.provider not in field_info.provider:
|
||||
continue
|
||||
if cred.type not in field_info.supported_types:
|
||||
continue
|
||||
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||
cred, field_info
|
||||
):
|
||||
continue
|
||||
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||
continue
|
||||
return cred
|
||||
return None
|
||||
|
||||
|
||||
def create_credential_meta_from_match(
|
||||
matching_cred: Credentials,
|
||||
) -> CredentialsMetaInput:
|
||||
"""Create a CredentialsMetaInput from a matched credential."""
|
||||
return CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
|
||||
|
||||
async def match_user_credentials_to_graph(
|
||||
user_id: str,
|
||||
graph: GraphModel,
|
||||
@@ -339,7 +244,7 @@ async def match_user_credentials_to_graph(
|
||||
missing_creds: list[str] = []
|
||||
|
||||
# Get aggregated credentials requirements from the graph
|
||||
aggregated_creds = graph.aggregate_credentials_inputs()
|
||||
aggregated_creds = graph.regular_credentials_inputs
|
||||
logger.debug(
|
||||
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
|
||||
)
|
||||
@@ -360,7 +265,7 @@ async def match_user_credentials_to_graph(
|
||||
_,
|
||||
_,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider, type, scopes, and host/URL
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
@@ -375,10 +280,6 @@ async def match_user_credentials_to_graph(
|
||||
cred.type != "host_scoped"
|
||||
or _credential_is_for_host(cred, credential_requirements)
|
||||
)
|
||||
and (
|
||||
cred.provider != ProviderName.MCP
|
||||
or _credential_is_for_mcp_server(cred, credential_requirements)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -430,6 +331,8 @@ def _credential_has_required_scopes(
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
@@ -449,22 +352,6 @@ def _credential_is_for_host(
|
||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||
|
||||
|
||||
def _credential_is_for_mcp_server(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""Check if an MCP OAuth credential matches the required server URL."""
|
||||
if not requirements.discriminator_values:
|
||||
return True
|
||||
|
||||
server_url = (
|
||||
credential.metadata.get("mcp_server_url")
|
||||
if isinstance(credential, OAuth2Credentials)
|
||||
else None
|
||||
)
|
||||
return server_url in requirements.discriminator_values if server_url else False
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for chat tools utility functions."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
|
||||
def _make_regular_field() -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
|
||||
def test_build_missing_credentials_excludes_auto_creds():
|
||||
"""
|
||||
build_missing_credentials_from_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from the "missing" set.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import (
|
||||
build_missing_credentials_from_graph,
|
||||
)
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
# regular_credentials_inputs should only return the non-auto field
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
|
||||
|
||||
# Should include the regular credential
|
||||
assert "github_api_key" in result
|
||||
# Should NOT include the auto_credential (not in regular_credentials_inputs)
|
||||
assert "google_oauth2" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_match_user_credentials_excludes_auto_creds():
|
||||
"""
|
||||
match_user_credentials_to_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from matching.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "test-graph"
|
||||
# regular_credentials_inputs returns only non-auto fields
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
# Mock the credentials manager to return no credentials
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
|
||||
) as MockCredsMgr:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_all_creds.return_value = []
|
||||
MockCredsMgr.return_value.store = mock_store
|
||||
|
||||
matched, missing = await match_user_credentials_to_graph(
|
||||
user_id="test-user", graph=mock_graph
|
||||
)
|
||||
|
||||
# No credentials available, so github should be missing
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
assert "github_api_key" in missing[0]
|
||||
@@ -0,0 +1,620 @@
|
||||
"""CoPilot tools for workspace file operations."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceFileListResponse(ToolResponseBase):
|
||||
"""Response containing list of workspace files."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||
files: list[WorkspaceFileInfoData]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
"""Response containing workspace file content (legacy, for small text files)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
download_url: str
|
||||
preview: str | None = None # First 500 chars for text files
|
||||
|
||||
|
||||
class WorkspaceWriteResponse(ToolResponseBase):
|
||||
"""Response after writing a file to workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
"""Response after deleting a file from workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||
file_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_workspace_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's workspace. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mimeType,
|
||||
size_bytes=f.sizeBytes,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, always return metadata+URL instead of inline content. "
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||
"""Check if the MIME type is a text-based type."""
|
||||
text_types = [
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
]
|
||||
return any(mime_type.startswith(t) for t in text_types)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Get file info
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||
|
||||
# Return inline content for small text files (unless force_download_url)
|
||||
if is_small_file and is_text_file and not force_download_url:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Return metadata + workspace:// reference for large or binary files
|
||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||
download_url = f"workspace://{target_file_id}"
|
||||
|
||||
# Generate preview for text files
|
||||
preview: str | None = None
|
||||
if is_text_file:
|
||||
try:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
if len(content) > self.PREVIEW_SIZE:
|
||||
preview_text += "..."
|
||||
preview = preview_text
|
||||
except Exception:
|
||||
pass # Preview is optional
|
||||
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
size_bytes=file_info.sizeBytes,
|
||||
download_url=download_url,
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WriteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for writing files to workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the user's workspace. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded file content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional virtual path where to save the file "
|
||||
"(e.g., '/documents/report.pdf'). "
|
||||
"Defaults to '/{filename}'. Scoped to current session."
|
||||
),
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional MIME type of the file. "
|
||||
"Auto-detected from filename if not provided."
|
||||
),
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename", "content_base64"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check size
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
file_record = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.sizeBytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for deleting files from workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
success = await manager.delete_file(target_file_id)
|
||||
|
||||
if not success:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message="File deleted successfully",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user