Compare commits

..

7 Commits

Author SHA1 Message Date
claude[bot]
657190e759 fix(frontend): address latest CodeRabbit review suggestions
- Use valid sort value "runs" instead of undefined in MainSearchResultPage
  test defaultProps to match production default and satisfy type contract
- Remove redundant marketplacePage.goto() navigation in E2E test since
  the page is already at /marketplace after login

Co-authored-by: Ubbe <0ubbe@users.noreply.github.com>
2026-02-12 15:02:18 +00:00
claude[bot]
caabee9278 fix(frontend): address CodeRabbit review suggestions for marketplace tests
- Fix filename typo: supress → suppress and update imports
- Replace waitFor + getByText/getByRole with findByText/findByRole (idiomatic RTL async queries)
- Remove unnecessary comments in test files per coding guidelines
- Fix operator precedence with explicit parentheses in suppress helper
- Remove redundant `undefined as undefined` type casts
- Extract inline props to `interface Props` in MockOnboardingProvider
- Widen body type in create-500-handler from Record<string,unknown> to unknown
- Add isValidating reset in mock-supabase-auth helpers
- Add missing creators MSW handler in no-results tests
- Clean up vitest.setup.tsx: replace nested afterAll with module-scoped variable
- Fix lint errors: unused imports (act, matchesUrl) and unused params
- Fix formatting in custom-mutator.ts

Co-authored-by: Ubbe <0ubbe@users.noreply.github.com>
2026-02-12 14:36:34 +00:00
Otto
0fcaa63162 style(frontend): fix formatting in marketplace integration tests 2026-01-30 06:34:39 +00:00
Abhimanyu Yadav
6299045f98 Merge branch 'dev' into abhi/marketplace-integration-tests 2026-01-30 11:42:52 +05:30
Otto
24cd34ed3f refactor(frontend): reorganize marketplace integration tests into file-specific locations
- Split main.test.tsx files into dedicated test files:
  - rendering.test.tsx - Component rendering tests
  - auth-state.test.tsx - Authentication state tests
  - error-handling.test.tsx - API error handling tests

- Add new test files:
  - loading-state.test.tsx - Loading skeleton tests
  - empty-state.test.tsx - Empty data handling tests
  - no-results.test.tsx - Search with no results tests

Test coverage:
- MainMarketplacePage: 14 tests (5 files)
- MainAgentPage: 13 tests (3 files)
- MainCreatorPage: 10 tests (3 files)
- MainSearchResultPage: 11 tests (4 files)
- Total: 48 tests across 15 files
2026-01-30 06:11:53 +00:00
abhi1992002
876c6677de fix(frontend): enhance testing and error handling in marketplace components
### Changes 🏗️
- Updated `MainMarketplacePage` tests to include rendering checks for various sections and error handling for API failures.
- Improved `AgentInfo` component to filter out NaN values from version numbers.
- Modified `customMutator` to conditionally log errors based on the environment.
- Enhanced Vitest configuration for better integration testing setup.
- Refactored existing tests for marketplace agents and creators to focus on cross-page flows.

### Checklist 📋
- [x] Verified that all tests pass with the new changes.
- [x] Ensured comprehensive coverage for error handling scenarios in tests.
- [x] Updated documentation for testing practices in `CLAUDE.md`.
2026-01-23 12:26:00 +05:30
abhi1992002
3e3af45456 fix(frontend): update testing setup with @testing-library/jest-dom and happy-dom
### Changes 🏗️
- Removed `happy-dom` from `devDependencies` and added it back in a different section for clarity.
- Added `@testing-library/jest-dom` to `devDependencies` for improved testing assertions.
- Updated `tsconfig.json` to include types for `@testing-library/jest-dom`.
- Configured Vitest to enable global variables for testing.
- Imported `@testing-library/jest-dom` in the Vitest setup file for enhanced testing capabilities.

### Checklist 📋
- [x] Verified that all tests pass with the new setup.
- [x] Ensured that the testing environment is correctly configured for integration tests.
2026-01-23 10:07:36 +05:30
985 changed files with 41431 additions and 92734 deletions

View File

@@ -1,17 +0,0 @@
---
name: backend-check
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Backend Check
## Steps
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
2. **Fix** any remaining errors manually, re-run until clean
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`

View File

@@ -1,35 +0,0 @@
---
name: code-style
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
user-invocable: false
metadata:
author: autogpt-team
version: "1.0.0"
---
# Code Style
## Imports
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
## Typing
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
## Code Structure
- **List comprehensions** over manual loop-and-append.
- **Early return** — guard clauses first, avoid deep nesting.
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
## Review Checklist
Before finishing, always ask:
- Can any function be split into smaller pieces?
- Is there unnecessary nesting that an early return would eliminate?
- Can any loop be a comprehension?
- Is there a simpler way to express this logic?

View File

@@ -1,16 +0,0 @@
---
name: frontend-check
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Frontend Check
## Steps (in order)
1. **Format**: `pnpm format` — NEVER run individual formatters
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user

View File

@@ -1,29 +0,0 @@
---
name: new-block
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# New Block Creation
Read `docs/platform/block-sdk-guide.md` first for the full guide.
## Steps
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
- `async def run` that `yield`s output fields
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
5. **Format**: `poetry run format`
## Rules
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
- Use top-level imports, avoid duck typing
- Always use `for_block_output` for block outputs

View File

@@ -1,28 +0,0 @@
---
name: openapi-regen
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# OpenAPI Spec Regeneration
## Steps
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
cd ../frontend && pnpm generate:api:force
kill $REST_PID
pnpm types && pnpm lint && pnpm format
```
## Rules
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
- Don't manually edit files in `src/app/api/__generated__/`
- Generated hooks follow: `use{Method}{Version}{OperationName}`

View File

@@ -1,31 +0,0 @@
---
name: pr-create
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Create Pull Request
## Steps
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
5. **Fill out PR template** as the body — be thorough in the Changes section
6. **Format first** (if relevant changes exist):
- Backend: `cd autogpt_platform/backend && poetry run format`
- Frontend: `cd autogpt_platform/frontend && pnpm format`
- Fix any lint errors, then commit formatting changes before pushing
7. **Push**: `git push -u origin HEAD`
8. **Create PR**: `gh pr create --base dev`
9. **Output** the PR URL
## Rules
- Always target `dev` branch
- Do NOT run tests — CI will handle that
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`

View File

@@ -1,51 +0,0 @@
---
name: pr-review
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# PR Review Comment Workflow
## Steps
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
2. **Fetch comments** (all three sources):
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
3. **Skip** comments already reacted to by PR author
4. **For each unreacted comment**:
- Read referenced code, make the fix (or reply if you disagree/need info)
- **Inline review comments** (`pulls/{N}/comments`):
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
- **PR conversation comments** (`issues/{N}/comments`):
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
- No threaded replies — post a new issue comment if needed
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
5. **Include autogpt-reviewer bot fixes** too
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
7. **Commit & push**
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
9. **Stay productive while CI runs** — don't idle. In priority order:
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
- Address any remaining comments
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
10. **If CI fails** — fix, go back to step 6
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
12. **Done** only when: all comments reacted AND CI is green.
## CRITICAL: Do Not Stop
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
## Rules
- One todo per comment
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
- React to every comment: +1 addressed, -1 disagreed (with explanation)

View File

@@ -1,45 +0,0 @@
---
name: worktree-setup
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
metadata:
author: autogpt-team
version: "1.0.0"
---
# Worktree Setup
## Preferred: Use Branchlet
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
```bash
npm install -g branchlet # install once
branchlet create -n <name> -s <source-branch> -b <new-branch>
branchlet list --json # list all worktrees
```
## Manual Fallback
If branchlet isn't available:
1. `git worktree add ../<RepoName><N> <branch-name>`
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
3. Install deps:
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
- `cd autogpt_platform/frontend && pnpm install`
## Running the App
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
```bash
for port in 8001 8002 8003 8005 8006 8007 8008; do
lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd <worktree>/autogpt_platform/backend && poetry run app
```
## CoPilot Testing Gotcha
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.

View File

@@ -5,13 +5,42 @@
!docs/
# Platform - Libs
!autogpt_platform/autogpt_libs/
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
!autogpt_platform/autogpt_libs/poetry.lock
!autogpt_platform/autogpt_libs/README.md
# Platform - Backend
!autogpt_platform/backend/
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/
!autogpt_platform/market/scripts.py
!autogpt_platform/market/schema.prisma
!autogpt_platform/market/pyproject.toml
!autogpt_platform/market/poetry.lock
!autogpt_platform/market/README.md
# Platform - Frontend
!autogpt_platform/frontend/
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
!autogpt_platform/frontend/README.md
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/
@@ -35,38 +64,6 @@
# Classic - Frontend
!classic/frontend/build/web/
# Explicitly re-ignore unwanted files from whitelisted directories
# Note: These patterns MUST come after the whitelist rules to take effect
# Hidden files and directories (but keep frontend .env files needed for build)
**/.*
!autogpt_platform/frontend/.env
!autogpt_platform/frontend/.env.default
!autogpt_platform/frontend/.env.production
# Python artifacts
**/__pycache__/
**/*.pyc
**/*.pyo
**/.venv/
**/.ruff_cache/
**/.pytest_cache/
**/.coverage
**/htmlcov/
# Node artifacts
**/node_modules/
**/.next/
**/storybook-static/
**/playwright-report/
**/test-results/
# Build artifacts
**/dist/
**/build/
!autogpt_platform/frontend/src/**/build/
**/target/
# Logs and temp files
**/*.log
**/*.tmp
# Explicitly re-ignore some folders
.*
**/__pycache__

File diff suppressed because it is too large Load Diff

View File

@@ -107,7 +107,7 @@ jobs:
- if: github.event_name == 'push'
name: Log in to Docker hub
uses: docker/login-action@v4
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}

View File

@@ -23,7 +23,7 @@ jobs:
uses: actions/checkout@v4
- name: Log in to Docker hub
uses: docker/login-action@v4
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_PASSWORD }}

View File

@@ -49,7 +49,7 @@ jobs:
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
uses: peter-evans/create-pull-request@v7
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.workflow_run.head_branch }}
fetch-depth: 0
@@ -40,51 +40,9 @@ jobs:
git checkout -b "$BRANCH_NAME"
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
# Backend Python/Poetry setup (so Claude can run linting/tests)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const run = await github.rest.actions.getWorkflowRun({

View File

@@ -30,7 +30,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -41,7 +41,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -77,15 +77,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
@@ -112,7 +124,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
@@ -297,7 +309,6 @@ jobs:
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
allowed_bots: "dependabot[bot]"
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: |

View File

@@ -40,7 +40,7 @@ jobs:
actions: read # Required for CI access
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -57,7 +57,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -93,15 +93,27 @@ jobs:
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Enable corepack
run: corepack enable
- name: Set up Node.js
uses: actions/setup-node@v6
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
node-version: "22"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
@@ -128,7 +140,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -58,11 +58,11 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -27,7 +27,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -39,7 +39,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -89,7 +89,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -7,10 +7,6 @@ on:
- "docs/integrations/**"
- "autogpt_platform/backend/backend/blocks/**"
concurrency:
group: claude-docs-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
claude-review:
# Only run for PRs from members/collaborators
@@ -27,7 +23,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
@@ -37,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -95,35 +91,5 @@ jobs:
3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment
## IMPORTANT: Comment Marker
Start your PR comment with exactly this HTML comment marker on its own line:
<!-- CLAUDE_DOCS_REVIEW -->
This marker is used to identify and replace your comment on subsequent runs.
Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it.
- name: Delete old Claude review comments
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Get all comment IDs with our marker, sorted by creation date (oldest first)
COMMENT_IDS=$(gh api \
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
# Count comments
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
if [ "$COMMENT_COUNT" -gt 1 ]; then
# Delete all but the last (newest) comment
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
if [ -n "$COMMENT_ID" ]; then
echo "Deleting old review comment: $COMMENT_ID"
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
fi
done
else
echo "No old review comments to clean up"
fi

View File

@@ -28,7 +28,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 1
@@ -38,7 +38,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -25,7 +25,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
@@ -52,7 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -17,7 +17,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name || 'master' }}
@@ -45,7 +45,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger deploy workflow
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DEPLOY_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure

View File

@@ -41,18 +41,13 @@ jobs:
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:4.1.4
image: rabbitmq:3.12-management
ports:
- 5672:5672
- 15672:15672
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
options: >-
--health-cmd "rabbitmq-diagnostics -q ping"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 10s
clamav:
image: clamav/clamav-debian:latest
ports:
@@ -73,7 +68,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
@@ -93,7 +88,7 @@ jobs:
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: Check comment permissions and deployment status
id: check_status
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
- name: Post permission denied comment
if: steps.check_status.outputs.permission_denied == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
- name: Get PR details for deployment
id: pr_details
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const pr = await github.rest.pulls.get({
@@ -82,7 +82,7 @@ jobs:
- name: Dispatch Deploy Event
if: steps.check_status.outputs.should_deploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -98,7 +98,7 @@ jobs:
- name: Post deploy success comment
if: steps.check_status.outputs.should_deploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -110,7 +110,7 @@ jobs:
- name: Dispatch Undeploy Event (from comment)
if: steps.check_status.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -126,7 +126,7 @@ jobs:
- name: Post undeploy success comment
if: steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
- name: Check deployment status on PR close
id: check_pr_close
if: github.event_name == 'pull_request' && github.event.action == 'closed'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const comments = await github.rest.issues.listComments({
@@ -168,7 +168,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: peter-evans/repository-dispatch@v4
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.DISPATCH_TOKEN }}
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
@@ -187,7 +187,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({

View File

@@ -6,16 +6,10 @@ on:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
pull_request:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
merge_group:
workflow_dispatch:
@@ -32,31 +26,34 @@ jobs:
setup:
runs-on: ubuntu-latest
outputs:
components-changed: ${{ steps.filter.outputs.components }}
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Check for component changes
uses: dorny/paths-filter@v3
id: filter
- name: Set up Node.js
uses: actions/setup-node@v4
with:
filters: |
components:
- 'autogpt_platform/frontend/src/components/**'
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Install dependencies to populate cache
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
lint:
@@ -65,17 +62,24 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -86,27 +90,31 @@ jobs:
chromatic:
runs-on: ubuntu-latest
needs: setup
# Disabled: to re-enable, remove 'false &&' from the condition below
if: >-
false
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
&& needs.setup.outputs.components-changed == 'true'
# Only run on dev branch pushes or PRs targeting dev
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -121,20 +129,30 @@ jobs:
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
- name: Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
@@ -142,125 +160,77 @@ jobs:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with:
driver: docker-container
driver-opts: network=host
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
restore-keys: |
${{ runner.os }}-buildx-frontend-test-
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
- name: Run docker compose
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
env:
NEXT_PUBLIC_PW_TEST: true
DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
- name: Move cache
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
rm -rf /tmp/.buildx-cache
if [ -d "/tmp/.buildx-cache-new" ]; then
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
fi
- name: Set up Platform - Run migrations
- name: Wait for services to be ready
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
- name: Create E2E test data
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# First try to run the script from inside the container
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
echo "✅ Found e2e_test_data.py in container, running it..."
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
else
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
# Copy the script into the container and run it
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
echo "❌ Failed to copy script to container"
exit 1
}
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
fi
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Set up tests - Install dependencies
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
@@ -287,7 +257,7 @@ jobs:
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
run: docker compose -f ../docker-compose.yml logs
integration_test:
runs-on: ubuntu-latest
@@ -295,19 +265,26 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Set up Node
uses: actions/setup-node@v6
- name: Restore dependencies cache
uses: actions/cache@v4
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile

View File

@@ -29,10 +29,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -56,19 +56,19 @@ jobs:
run: pnpm install --frozen-lockfile
types:
runs-on: big-boi
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -85,10 +85,10 @@ jobs:
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -1,39 +0,0 @@
name: PR Overlap Detection
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- master
permissions:
contents: read
pull-requests: write
jobs:
check-overlaps:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for merge testing
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Configure git
run: |
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
- name: Run overlap detection
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Always succeed - this check informs contributors, it shouldn't block merging
continue-on-error: true
run: |
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}

View File

@@ -11,7 +11,7 @@ jobs:
steps:
# - name: Wait some time for all actions to start
# run: sleep 30
- uses: actions/checkout@v6
- uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Set up Python

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
Add cache configuration to a resolved docker-compose file for all services
that have a build key, and ensure image names match what docker compose expects.
"""
import argparse
import yaml
DEFAULT_BRANCH = "dev"
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
def main():
parser = argparse.ArgumentParser(
description="Add cache config to a resolved compose file"
)
parser.add_argument(
"--source",
required=True,
help="Source compose file to read (should be output of `docker compose config`)",
)
parser.add_argument(
"--cache-from",
default="type=gha",
help="Cache source configuration",
)
parser.add_argument(
"--cache-to",
default="type=gha,mode=max",
help="Cache destination configuration",
)
for component in CACHE_BUILDS_FOR_COMPONENTS:
parser.add_argument(
f"--{component}-hash",
default="",
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
)
parser.add_argument(
"--git-ref",
default="",
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
)
args = parser.parse_args()
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
git_ref_scope = ""
if args.git_ref:
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
with open(args.source, "r") as f:
compose = yaml.safe_load(f)
# Get project name from compose file or default
project_name = compose.get("name", "autogpt_platform")
def get_image_name(dockerfile: str, target: str) -> str:
"""Generate image name based on Dockerfile folder and build target."""
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
if len(dockerfile_parts) >= 2:
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
else:
folder_name = "app"
return f"{project_name}-{folder_name}:{target}"
def get_build_key(dockerfile: str, target: str) -> str:
"""Generate a unique key for a Dockerfile+target combination."""
return f"{dockerfile}:{target}"
def get_component(dockerfile: str) -> str | None:
"""Get component name (frontend/backend) from dockerfile path."""
for component in CACHE_BUILDS_FOR_COMPONENTS:
if component in dockerfile:
return component
return None
# First pass: collect all services with build configs and identify duplicates
# Track which (dockerfile, target) combinations we've seen
build_key_to_first_service: dict[str, str] = {}
services_to_build: list[str] = []
services_to_dedupe: list[str] = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "default")
build_key = get_build_key(dockerfile, target)
if build_key not in build_key_to_first_service:
# First service with this build config - it will do the actual build
build_key_to_first_service[build_key] = service_name
services_to_build.append(service_name)
else:
# Duplicate - will just use the image from the first service
services_to_dedupe.append(service_name)
# Second pass: configure builds and deduplicate
modified_services = []
for service_name, service_config in compose.get("services", {}).items():
if "build" not in service_config:
continue
build_config = service_config["build"]
dockerfile = build_config.get("dockerfile", "Dockerfile")
target = build_config.get("target", "latest")
image_name = get_image_name(dockerfile, target)
# Set image name for all services (needed for both builders and deduped)
service_config["image"] = image_name
if service_name in services_to_dedupe:
# Remove build config - this service will use the pre-built image
del service_config["build"]
continue
# This service will do the actual build - add cache config
cache_from_list = []
cache_to_list = []
component = get_component(dockerfile)
if not component:
# Skip services that don't clearly match frontend/backend
continue
# Get the hash for this component
component_hash = getattr(args, f"{component}_hash")
# Scope format: platform-{component}-{target}-{hash|ref}
# Example: platform-backend-server-abc123
if "type=gha" in args.cache_from:
# 1. Primary: exact hash match (most specific)
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
# 2. Fallback: branch-based cache
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
# 3. Fallback: dev branch cache (for PRs/feature branches)
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
if "type=gha" in args.cache_to:
# Write to both hash-based and branch-based scopes
if component_hash:
hash_scope = f"platform-{component}-{target}-{component_hash}"
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
if git_ref_scope:
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
# Ensure we have at least one cache source/target
if not cache_from_list:
cache_from_list.append(args.cache_from)
if not cache_to_list:
cache_to_list.append(args.cache_to)
build_config["cache_from"] = cache_from_list
build_config["cache_to"] = cache_to_list
modified_services.append(service_name)
# Write back to the same file
with open(args.source, "w") as f:
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
for svc in modified_services:
svc_config = compose["services"][svc]
build_cfg = svc_config.get("build", {})
cache_from_list = build_cfg.get("cache_from", ["none"])
cache_to_list = build_cfg.get("cache_to", ["none"])
print(f" - {svc}")
print(f" image: {svc_config.get('image', 'N/A')}")
print(f" cache_from: {cache_from_list}")
print(f" cache_to: {cache_to_list}")
if services_to_dedupe:
print(
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
)
for svc in services_to_dedupe:
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
if __name__ == "__main__":
main()

3
.gitignore vendored
View File

@@ -180,6 +180,3 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
# Implementation plans (generated by AI agents)
plans/

1
.nvmrc
View File

@@ -1 +0,0 @@
22

View File

@@ -1,10 +1,3 @@
default_install_hook_types:
- pre-commit
- pre-push
- post-checkout
default_stages: [pre-commit]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -24,7 +17,6 @@ repos:
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
@@ -34,106 +26,49 @@ repos:
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Backend
alias: poetry-install-platform-backend
entry: poetry -C autogpt_platform/backend install
# include autogpt_libs source (since it's a path dependency)
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/backend install
'
always_run: true
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Libs
alias: poetry-install-platform-libs
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/autogpt_libs install
'
always_run: true
entry: poetry -C autogpt_platform/autogpt_libs install
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: pnpm-install
name: Check & Install dependencies - AutoGPT Platform - Frontend
alias: pnpm-install-platform-frontend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
pnpm --prefix autogpt_platform/frontend install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
entry: poetry -C classic/original_autogpt install
# include forge source (since it's a path dependency)
always_run: true
files: ^classic/(original_autogpt|forge)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
entry: poetry -C classic/forge install
files: ^classic/forge/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
entry: poetry -C classic/benchmark install
files: ^classic/benchmark/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: local
# For proper type checking, Prisma client must be up-to-date.
@@ -141,54 +76,12 @@ repos:
- id: prisma-generate
name: Prisma Generate - AutoGPT Platform - Backend
alias: prisma-generate-platform-backend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
cd autogpt_platform/backend
&& poetry run prisma generate
&& poetry run gen-prisma-stub
'
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
# include everything that triggers poetry install + the prisma schema
always_run: true
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: export-api-schema
name: Export API schema - AutoGPT Platform - Backend -> Frontend
alias: export-api-schema-platform
entry: >
bash -c '
cd autogpt_platform/backend
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
&& cd ../frontend
&& pnpm prettier --write ./src/app/api/openapi.json
'
files: ^autogpt_platform/backend/
language: system
pass_filenames: false
- id: generate-api-client
name: Generate API client - AutoGPT Platform - Frontend
alias: generate-api-client-platform-frontend
entry: >
bash -c '
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
else
git diff --quiet HEAD -- "$SCHEMA" && exit 0
fi;
cd autogpt_platform/frontend && pnpm generate:api
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2

View File

@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -1,3 +1,2 @@
*.ignore.*
*.ign.*
.application.logs
*.ign.*

View File

@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Branching Strategy
- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.

File diff suppressed because it is too large Load Diff

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.7"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.15.0"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.28.0"
uvicorn = "^0.40.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.408"
pyright = "^1.1.404"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-cov = "^7.0.0"
ruff = "^0.15.0"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"
[build-system]
requires = ["poetry-core"]

View File

@@ -104,12 +104,6 @@ TWITTER_CLIENT_SECRET=
# Make a new workspace for your OAuth APP -- trust me
# https://linear.app/settings/api/applications/new
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
LINEAR_API_KEY=
# Linear project and team IDs for the feature request tracker.
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
# and in team settings. Used by the chat copilot to file and search feature requests.
LINEAR_FEATURE_REQUEST_PROJECT_ID=
LINEAR_FEATURE_REQUEST_TEAM_ID=
LINEAR_CLIENT_ID=
LINEAR_CLIENT_SECRET=
@@ -158,7 +152,6 @@ REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
ELEVENLABS_API_KEY=
# Data & Search Services
E2B_API_KEY=
@@ -190,8 +183,5 @@ ZEROBOUNCE_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Tally Form Integration (pre-populate business understanding on signup)
TALLY_API_KEY=
# Other Services
AUTOMOD_API_KEY=

View File

@@ -19,6 +19,3 @@ load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
migrations/*/rollback*.sql
# Workspace files
workspaces/

View File

@@ -1,5 +1,3 @@
# ============================ DEPENDENCY BUILDER ============================ #
FROM debian:13-slim AS builder
# Set environment variables
@@ -53,106 +51,58 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
# =============================== DB MIGRATOR =============================== #
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
FROM debian:13-slim AS migrate
WORKDIR /app/autogpt_platform/backend
ENV DEBIAN_FRONTEND=noninteractive
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy Node.js from builder (needed for Prisma CLI)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install prisma-client-py directly (much smaller than copying full venv)
RUN pip3 install prisma>=0.15.0 --break-system-packages
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
FROM debian:13-slim AS server_dependencies
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool (fallback when E2B is not configured).
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma and agent-browser.
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
# COPY resolves them to regular files, breaking require() paths. Recreate as
# proper symlinks so npm/npx can find their modules.
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
RUN apt-get update && apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1 \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& agent-browser install \
&& rm -rf /tmp/* /root/.npm
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
FROM server_dependencies AS migrate
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
# Install the project package to create entry point scripts in .venv/bin/
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
poetry install --no-ansi --only-root
RUN poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["rest"]
CMD ["poetry", "run", "rest"]

View File

@@ -1,9 +1,4 @@
"""Common test fixtures for server tests.
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""
"""Common test fixtures for server tests."""
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""

View File

@@ -88,23 +88,20 @@ async def require_auth(
)
def require_permission(*permissions: APIKeyPermission):
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking required permissions.
All listed permissions must be present.
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
async def check_permissions(
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
missing = [p for p in permissions if p not in auth.scopes]
if missing:
if permission not in auth.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission(s): "
f"{', '.join(p.value for p in missing)}",
detail=f"Missing required permission: {permission.value}",
)
return auth
return check_permissions
return check_permission

View File

@@ -1,7 +1,7 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Optional, Sequence
from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,17 +9,15 @@ from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import backend.api.features.store.cache as store_cache
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
import backend.blocks
from backend.api.external.middleware import require_auth, require_permission
import backend.data.block
from backend.api.external.middleware import require_permission
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Settings
from .integrations import integrations_router
@@ -69,7 +67,7 @@ async def get_user_info(
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.blocks.get_blocks().values()]
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
@@ -85,7 +83,7 @@ async def execute_graph_block(
require_permission(APIKeyPermission.EXECUTE_BLOCK)
),
) -> CompletedBlockOutput:
obj = backend.blocks.get_block(block_id)
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
if obj.disabled:
@@ -97,43 +95,6 @@ async def execute_graph_block(
return output
@v1_router.post(
path="/graphs",
tags=["graphs"],
status_code=201,
dependencies=[
Security(
require_permission(
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
)
)
],
)
async def create_graph(
graph: graph_db.Graph,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
),
) -> graph_db.GraphModel:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID.
It is automatically added to the user's library.
"""
from backend.api.features.library import db as library_db
graph_model = graph_db.make_graph_model(graph, auth.user_id)
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph_model.validate_graph(for_run=False)
await graph_db.create_graph(graph_model, user_id=auth.user_id)
await library_db.create_library_agent(graph_model, auth.user_id)
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
return activated_graph
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
@@ -231,13 +192,13 @@ async def get_graph_execution_results(
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: store_db.StoreAgentsSortOptions | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
@@ -279,7 +240,7 @@ async def get_store_agents(
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
@@ -307,13 +268,13 @@ async def get_store_agent(
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
@@ -349,7 +310,7 @@ async def get_store_creators(
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorDetails,
)
async def get_store_creator(

View File

@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.copilot.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.api.features.chat.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo
logger = logging.getLogger(__name__)

View File

@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=store_model.StoreListingsWithVersionsResponse,
)
async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
):
"""
Get store listings with their version history for admins.
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
page_size: Number of items per page
Returns:
Paginated listings with their versions
StoreListingsWithVersionsResponse with listings and their versions
"""
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
try:
listings = await store_db.get_admin_listings_with_versions(
status=status,
search_query=search,
page=page,
page_size=page_size,
)
return listings
except Exception as e:
logger.exception("Error getting admin listings with versions: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving listings with versions"
},
)
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=store_model.StoreSubmission,
)
async def review_submission(
store_listing_version_id: str,
request: store_model.ReviewSubmissionRequest,
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
) -> store_model.StoreSubmissionAdminView:
):
"""
Review a store listing submission.
@@ -73,24 +84,31 @@ async def review_submission(
user_id: Authenticated admin user performing the review
Returns:
StoreSubmissionAdminView with updated review information
StoreSubmission with updated review information
"""
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
try:
already_approved = await store_db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
submission = await store_db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
external_comments=request.comments,
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches whenever approval state changes, since store visibility can change
if state_changed:
store_cache.clear_all_caches()
return submission
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(

View File

@@ -1,26 +1,20 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Any, Sequence, get_args, get_origin
from typing import Sequence
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
import backend.data.block
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
BlockCategory,
BlockInfo,
BlockSchema,
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -28,7 +22,7 @@ from backend.util.models import Pagination
from .model import (
BlockCategoryResponse,
BlockResponse,
BlockTypeFilter,
BlockType,
CountResponse,
FilterType,
Provider,
@@ -43,16 +37,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
# Boost blocks over marketplace agents in search results
BLOCK_SCORE_BOOST = 50.0
# Block IDs to exclude from search results
EXCLUDED_BLOCK_IDS = frozenset(
{
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
}
)
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@@ -75,8 +59,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled and excluded blocks
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
@@ -104,7 +88,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
def get_blocks(
*,
category: str | None = None,
type: BlockTypeFilter | None = None,
type: BlockType | None = None,
provider: ProviderName | None = None,
page: int = 1,
page_size: int = 50,
@@ -127,9 +111,6 @@ def get_blocks(
# Skip disabled blocks
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
@@ -269,25 +250,14 @@ async def _build_cached_search_results(
"my_agents": 0,
}
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
elif include_blocks or include_integrations:
# No query - list all blocks using in-memory approach
block_results, block_total, integration_total = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
@@ -332,14 +302,10 @@ async def _build_cached_search_results(
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Collect all blocks for listing (no search query).
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
@@ -352,10 +318,6 @@ def _collect_block_results(
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
@@ -365,6 +327,10 @@ def _collect_block_results(
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
@@ -375,122 +341,8 @@ def _collect_block_results(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=BLOCK_SCORE_BOOST,
sort_key=block_info.name.lower(),
)
)
return results, block_count, integration_count
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using hybrid search with builder-specific filtering.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
normalized_query = query.strip().lower()
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for result in search_results:
block_id = result["content_id"]
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
score=score,
sort_key=_get_item_name(block_info),
)
)
@@ -615,8 +467,6 @@ async def _get_static_counts():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
all_blocks += 1
@@ -643,25 +493,47 @@ async def _get_static_counts():
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if _contains_type(field.annotation, LlmModel):
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
@@ -768,32 +640,45 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600, shared_cache=True)
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
"""Return the most-executed blocks from the last 14 days.
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
and returns the top `count` blocks sorted by execution count, excluding
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
"""
results = await mv_suggested_blocks.prisma().find_many()
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
# Get the top blocks based on execution count
# But ignore Input, Output, Agent, and excluded blocks
# But ignore Input and Output blocks
blocks: list[tuple[BlockInfo, int]] = []
execution_counts = {row.block_id: row.execution_count for row in results}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
if block.disabled or block.block_type in (
BlockType.INPUT,
BlockType.OUTPUT,
BlockType.AGENT,
backend.data.block.BlockType.INPUT,
backend.data.block.BlockType.OUTPUT,
backend.data.block.BlockType.AGENT,
):
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
execution_count = execution_counts.get(block.id, 0)
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
import backend.api.features.library.model as library_model
import backend.api.features.store.model as store_model
from backend.blocks._base import BlockInfo
from backend.data.block import BlockInfo
from backend.integrations.providers import ProviderName
from backend.util.models import Pagination
@@ -15,7 +15,7 @@ FilterType = Literal[
"my_agents",
]
BlockTypeFilter = Literal["all", "input", "action", "output"]
BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel):
@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
providers: list[ProviderName]
top_blocks: list[BlockInfo]

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Sequence, cast, get_args
from typing import Annotated, Sequence
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
from . import db as builder_db
from . import model as builder_model
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
@@ -51,6 +49,11 @@ async def get_suggestions(
Get all suggestions for the Blocks Menu.
"""
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
providers=[
ProviderName.TWITTER,
@@ -85,7 +88,7 @@ async def get_block_categories(
)
async def get_blocks(
category: Annotated[str | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockTypeFilter | None, fastapi.Query()] = None,
type: Annotated[builder_model.BlockType | None, fastapi.Query()] = None,
provider: Annotated[ProviderName | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,
@@ -148,7 +151,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -157,20 +160,9 @@ async def search(
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# Parse and validate filter parameter
filters: list[builder_model.FilterType]
if filter:
filter_values = [f.strip() for f in filter.split(",")]
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
if invalid_filters:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
)
filters = cast(list[builder_model.FilterType], filter_values)
else:
filters = [
# If no filters are provided, then we will return all types
if not filter:
filter = [
"blocks",
"integrations",
"marketplace_agents",
@@ -182,7 +174,7 @@ async def search(
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filters,
filters=filter,
by_creator=by_creator,
)
@@ -204,7 +196,7 @@ async def search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filters,
filter=filter,
by_creator=by_creator,
search_id=search_id,
),

View File

@@ -0,0 +1,96 @@
"""Configuration management for chat system."""
import os
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
description="Base URL for API (e.g., for OpenRouter)",
)
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
"""Get API key from environment if not provided."""
if v is None:
# Try to get from environment variables
# First check for CHAT_API_KEY (Pydantic prefix)
v = os.getenv("CHAT_API_KEY")
if not v:
# Fall back to OPEN_ROUTER_API_KEY
v = os.getenv("OPEN_ROUTER_API_KEY")
if not v:
# Fall back to OPENAI_API_KEY
v = os.getenv("OPENAI_API_KEY")
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):
"""Get base URL from environment if not provided."""
if v is None:
# Check for OpenRouter or custom base URL
v = os.getenv("CHAT_BASE_URL")
if not v:
v = os.getenv("OPENROUTER_BASE_URL")
if not v:
v = os.getenv("OPENAI_BASE_URL")
if not v:
v = "https://openrouter.ai/api/v1"
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
"onboarding": "prompts/onboarding_system.md",
}
class Config:
"""Pydantic config."""
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore" # Ignore extra environment variables

View File

@@ -0,0 +1,291 @@
"""Database operations for chat sessions."""
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any, cast
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from backend.data.db import transaction
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": True},
)
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str,
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
userId=user_id,
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(
session_id: str,
credentials: dict[str, Any] | None = None,
successful_agent_runs: dict[str, Any] | None = None,
successful_agent_schedules: dict[str, Any] | None = None,
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
data["credentials"] = SafeJson(credentials)
if successful_agent_runs is not None:
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
if successful_agent_schedules is not None:
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
if total_prompt_tokens is not None:
data["totalPromptTokens"] = total_prompt_tokens
if total_completion_tokens is not None:
data["totalCompletionTokens"] = total_completion_tokens
if title is not None:
data["title"] = title
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": True},
)
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
session_id: str,
role: str,
sequence: int,
content: str | None = None,
name: str | None = None,
tool_call_id: str | None = None,
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
# Add optional string fields
if content is not None:
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
data["toolCalls"] = SafeJson(tool_calls)
if function_call is not None:
data["functionCall"] = SafeJson(function_call)
# Run message create and session timestamp update in parallel for lower latency
_, message = await asyncio.gather(
PrismaChatSession.prisma().update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return message
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
return []
created_messages = []
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
"""Delete a chat session and all its messages.
Args:
session_id: The session ID to delete.
user_id: If provided, validates that the session belongs to this user
before deletion. This prevents unauthorized deletion of other
users' sessions.
Returns:
True if deleted successfully, False otherwise.
"""
try:
# Build typed where clause with optional user_id validation
where_clause: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
where_clause["userId"] = user_id
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
if result == 0:
logger.warning(
f"No session deleted for {session_id} "
f"(user_id validation: {user_id is not None})"
)
return False
return True
except Exception as e:
logger.error(f"Failed to delete chat session {session_id}: {e}")
return False
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
session_id: str,
tool_call_id: str,
new_content: str,
) -> bool:
"""Update the content of a tool message in chat history.
Used by background tasks to update pending operation messages with final results.
Args:
session_id: The chat session ID.
tool_call_id: The tool call ID to find the message.
new_content: The new content to set.
Returns:
True if a message was updated, False otherwise.
"""
try:
result = await PrismaChatMessage.prisma().update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={
"content": new_content,
},
)
if result == 0:
logger.warning(
f"No message found to update for session {session_id}, "
f"tool_call_id {tool_call_id}"
)
return False
return True
except Exception as e:
logger.error(
f"Failed to update tool message for session {session_id}, "
f"tool_call_id {tool_call_id}: {e}"
)
return False

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Self, cast
from typing import Any
from weakref import WeakValueDictionary
from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# ===================== Chat data models ===================== #
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel):
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
tool_calls: list[dict] | None = None
function_call: dict | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
return ChatMessage(
role=prisma_message.role,
content=prisma_message.content,
name=prisma_message.name,
tool_call_id=prisma_message.toolCallId,
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
)
class Usage(BaseModel):
prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
total_tokens: int
class ChatSessionInfo(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
@@ -86,9 +104,40 @@ class ChatSessionInfo(BaseModel):
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
@staticmethod
def new(user_id: str) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
@@ -110,10 +159,11 @@ class ChatSessionInfo(BaseModel):
)
)
return cls(
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
@@ -122,56 +172,6 @@ class ChatSessionInfo(BaseModel):
successful_agent_schedules=successful_agent_schedules,
)
class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
if prisma_session.Messages is None:
raise ValueError(
f"Prisma session {prisma_session.id} is missing Messages relation"
)
return cls(
**ChatSessionInfo.from_db(prisma_session).model_dump(),
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -258,72 +258,43 @@ class ChatSession(ChatSessionInfo):
name=message.name or "",
)
)
return self._merge_consecutive_assistant_messages(messages)
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
return messages
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
# ================ Chat cache + DB operations ================ #
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session in Redis (without persisting to the database)."""
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
@@ -339,6 +310,80 @@ async def invalidate_session_cache(session_id: str) -> None:
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
@@ -370,7 +415,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -386,7 +431,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await cache_chat_session(session)
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -394,44 +439,6 @@ async def get_chat_session(
return session
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
session = await chat_db().get_chat_session(session_id)
if not session:
return None
logger.info(
f"Loaded session {session_id} from DB: "
f"has_messages={bool(session.messages)}, "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
@@ -451,35 +458,25 @@ async def upsert_chat_session(
lock = await _get_session_lock(session.session_id)
async with lock:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
)
db_error: Exception | None = None
# Save to database (primary storage)
try:
await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
)
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
)
db_error = e
# Save to cache (best-effort, even if DB failed).
# Title updates (update_session_title) run *outside* this lock because
# they only touch the title field, not messages. So a concurrent rename
# or auto-title may have written a newer title to Redis while this
# upsert was in progress. Always prefer the cached title to avoid
# overwriting it with the stale in-memory copy.
# Save to cache (best-effort, even if DB failed)
try:
existing_cached = await _get_session_from_cache(session.session_id)
if existing_cached and existing_cached.title:
session = session.model_copy(update={"title": existing_cached.title})
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
@@ -500,107 +497,6 @@ async def upsert_chat_session(
return session
async def _save_session_to_db(
session: ChatSession,
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> None:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
"""
db = chat_db()
if not skip_existence_check:
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -613,7 +509,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
@@ -625,7 +521,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Cache the session (best-effort optimization, DB is source of truth)
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -636,16 +532,20 @@ async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]:
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.
Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
db = chat_db()
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count
@@ -663,7 +563,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db().delete_chat_session(session_id, user_id)
deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted:
return False
@@ -680,89 +580,38 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
try:
from .tools.agent_browser import close_browser_session
await close_browser_session(session_id, user_id=user_id)
except Exception as e:
logger.debug(f"Browser cleanup for session {session_id}: {e}")
return True
async def update_session_title(
session_id: str,
user_id: str,
title: str,
*,
only_if_empty: bool = False,
) -> bool:
"""Update the title of a chat session, scoped to the owning user.
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
Lightweight operation that doesn't touch messages, avoiding race conditions
with concurrent message updates.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Args:
session_id: The session ID to update.
user_id: Owning user the DB query filters on this.
title: The new title to set.
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
so auto-generated titles never overwrite a user-set title.
Returns:
True if updated successfully, False otherwise (not found, wrong user,
or when only_if_empty title was already set).
True if updated successfully, False otherwise.
"""
try:
updated = await chat_db().update_chat_session_title(
session_id, user_id, title, only_if_empty=only_if_empty
)
if not updated:
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(
f"Cache title update failed for session {session_id} (non-critical): {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -0,0 +1,119 @@
import pytest
from .model import (
ChatMessage,
ChatSession,
Usage,
get_chat_session,
upsert_chat_session,
)
messages = [
ChatMessage(content="Hello, how are you?", role="user"),
ChatMessage(
content="I'm fine, thank you!",
role="assistant",
tool_calls=[
{
"id": "t123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "New York"}',
},
}
],
),
ChatMessage(
content="I'm using the tool to get the weather",
role="tool",
tool_call_id="t123",
),
]
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()
s2 = ChatSession.model_validate_json(serialized)
assert s2.model_dump() == s.model_dump()
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 == s
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage(setup_test_user, test_user_id):
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -5,18 +5,11 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
from backend.util.truncate import truncate
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -25,10 +18,6 @@ class ResponseType(str, Enum):
START = "start"
FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
@@ -52,8 +41,7 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
return f"data: {self.model_dump_json()}\n\n"
# ========== Message Lifecycle ==========
@@ -64,18 +52,6 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
sessionId: str | None = Field(
default=None,
description="Session ID for SSE reconnection.",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse):
@@ -84,26 +60,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ==========
@@ -151,16 +107,13 @@ class StreamToolInputAvailable(StreamBaseResponse):
)
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
@@ -168,19 +121,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def model_post_init(self, __context: Any) -> None:
"""Truncate oversized outputs after construction."""
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ==========
@@ -204,18 +144,6 @@ class StreamError(StreamBaseResponse):
default=None, description="Additional error details"
)
def to_sse(self) -> str:
"""Convert to SSE format, only emitting fields required by AI SDK protocol.
The AI SDK uses z.strictObject({type, errorText}) which rejects
any extra fields like `code` or `details`.
"""
data = {
"type": self.type.value,
"errorText": self.errorText,
}
return f"data: {json_dumps(data)}\n\n"
class StreamHeartbeat(StreamBaseResponse):
"""Heartbeat to keep SSE connection alive during long-running operations.

View File

@@ -1,66 +1,22 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockDetailsResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
NeedLoginResponse,
NoResultsResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
@@ -89,9 +45,6 @@ class StreamChatRequest(BaseModel):
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
class CreateSessionResponse(BaseModel):
@@ -102,13 +55,6 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
turn_id: str
last_message_id: str # Redis Stream message ID for resumption
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -117,7 +63,6 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -136,27 +81,6 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelSessionResponse(BaseModel):
"""Response model for the cancel session endpoint."""
cancelled: bool
reason: str | None = None
class UpdateSessionTitleRequest(BaseModel):
"""Request model for updating a session's title."""
title: str
@field_validator("title")
@classmethod
def title_must_not_be_blank(cls, v: str) -> str:
stripped = v.strip()
if not stripped:
raise ValueError("Title must not be blank")
return stripped
# ========== Routes ==========
@@ -231,92 +155,6 @@ async def create_session(
)
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
# Best-effort cleanup of the E2B sandbox (if any).
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
e2b_cfg = ChatConfig()
if e2b_cfg.e2b_active:
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
try:
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
)
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
dependencies=[Security(auth.requires_user)],
status_code=200,
responses={404: {"description": "Session not found or access denied"}},
)
async def update_session_title_route(
session_id: str,
request: UpdateSessionTitleRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> dict:
"""
Update the title of a chat session.
Allows the user to rename their chat session.
Args:
session_id: The session ID to update.
request: Request body containing the new title.
user_id: The authenticated user's ID.
Returns:
dict: Status of the update.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
success = await update_session_title(session_id, user_id, request.title)
if not success:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return {"status": "ok"}
@router.get(
"/sessions/{session_id}",
)
@@ -328,14 +166,13 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
SessionDetailResponse: Details for the requested session, or None if not found.
"""
session = await get_chat_session(session_id, user_id)
@@ -343,25 +180,11 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
return SessionDetailResponse(
id=session.session_id,
@@ -369,55 +192,9 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@router.post(
"/sessions/{session_id}/cancel",
status_code=200,
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
polls Redis until the task status flips from ``running`` or a timeout
(5 s) is reached. Returns only after the cancellation is confirmed.
"""
await _validate_and_get_session(session_id, user_id)
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
return CancelSessionResponse(cancelled=True, reason="no_active_session")
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
# Poll until the executor confirms the task is no longer running.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
session_state = await stream_registry.get_session(session_id)
if session_state is None or session_state.status != "running":
logger.info(
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
)
return CancelSessionResponse(cancelled=True)
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
)
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
return CancelSessionResponse(cancelled=True)
@router.post(
"/sessions/{session_id}/stream",
)
@@ -434,10 +211,6 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
@@ -446,246 +219,41 @@ async def stream_chat_post(
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
import time
session = await _validate_and_get_session(session_id, user_id)
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
}
},
)
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_name="chat",
turn_id=turn_id,
)
logger.info(
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
}
},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -702,94 +270,63 @@ async def stream_chat_post(
@router.get(
"/sessions/{session_id}/stream",
)
async def resume_session_stream(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Resume an active stream for a session.
Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load.
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier.
session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
if not active_session:
return Response(status_code=204)
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id="0-0",
)
if subscriber_queue is None:
return Response(status_code=204)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
exc_info=True,
)
logger.info(
"Resume stream completed",
extra={
"session_id": session_id,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -797,8 +334,8 @@ async def resume_session_stream(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -806,6 +343,7 @@ async def resume_session_stream(
@router.patch(
"/sessions/{session_id}/assign-user",
dependencies=[Security(auth.requires_user)],
status_code=200,
)
async def session_assign_user(
session_id: str,
@@ -828,26 +366,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========
@@ -884,43 +402,3 @@ async def health_check() -> dict:
"service": "chat",
"version": "0.1.0",
}
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| SuggestedGoalResponse
| BlockListResponse
| BlockDetailsResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| MCPToolsDiscoveredResponse
| MCPToolOutputResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -1,251 +0,0 @@
"""Tests for chat API routes: session title update and file attachment validation."""
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _mock_update_session_title(
mocker: pytest_mock.MockerFixture, *, success: bool = True
):
"""Mock update_session_title."""
return mocker.patch(
"backend.api.features.chat.routes.update_session_title",
new_callable=AsyncMock,
return_value=success,
)
# ─── Update title: success ─────────────────────────────────────────────
def test_update_title_success(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "My project"},
)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
def test_update_title_trims_whitespace(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_update = _mock_update_session_title(mocker, success=True)
response = client.patch(
"/sessions/sess-1/title",
json={"title": " trimmed "},
)
assert response.status_code == 200
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
def test_update_title_blank_rejected(
test_user_id: str,
) -> None:
"""Whitespace-only titles must be rejected before hitting the DB."""
response = client.patch(
"/sessions/sess-1/title",
json={"title": " "},
)
assert response.status_code == 422
def test_update_title_empty_rejected(
test_user_id: str,
) -> None:
response = client.patch(
"/sessions/sess-1/title",
json={"title": ""},
)
assert response.status_code == 422
# ─── Update title: session not found or wrong user → 404 ──────────────
def test_update_title_not_found(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
_mock_update_session_title(mocker, success=False)
response = client.patch(
"/sessions/sess-1/title",
json={"title": "New name"},
)
assert response.status_code == 404
# ─── file_ids Pydantic validation ─────────────────────────────────────
def test_stream_chat_rejects_too_many_file_ids():
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
},
)
# Should get past validation — 200 streaming response expected
assert response.status_code == 200
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [
valid_id,
"not-a-uuid",
"../../../etc/passwd",
"",
],
},
)
# The find_many call should only receive the one valid UUID
mock_prisma.find_many.assert_called_once()
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "file_ids": [fid]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
import logging
from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"

View File

@@ -1,42 +1,23 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.tracking import track_tool_called
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .fix_agent import FixAgentGraphTool
from .get_agent_building_guide import GetAgentBuildingGuideTool
from .get_doc_page import GetDocPageTool
from .get_mcp_guide import GetMCPGuideTool
from .manage_folders import (
CreateFolderTool,
DeleteFolderTool,
ListFoldersTool,
MoveAgentsToFolderTool,
MoveFolderTool,
UpdateFolderTool,
)
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
@@ -45,8 +26,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -54,41 +34,15 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
# Folder management tools
"create_folder": CreateFolderTool(),
"list_folders": ListFoldersTool(),
"update_folder": UpdateFolderTool(),
"move_folder": MoveFolderTool(),
"delete_folder": DeleteFolderTool(),
"move_agents_to_folder": MoveAgentsToFolderTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
"get_agent_building_guide": GetAgentBuildingGuideTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
"browser_navigate": BrowserNavigateTool(),
"browser_act": BrowserActTool(),
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),
# Agent generation tools (local validation/fixing)
"validate_agent_graph": ValidateAgentGraphTool(),
"fix_agent_graph": FixAgentGraphTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
@@ -100,17 +54,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
def get_available_tools() -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
"""
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
def get_tool(tool_name: str) -> BaseTool | None:

View File

@@ -1,46 +1,22 @@
import logging
import uuid
from datetime import UTC, datetime
from os import getenv
import pytest
import pytest_asyncio
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data import db as db_module
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
_logger = logging.getLogger(__name__)
async def _ensure_db_connected() -> None:
"""Ensure the Prisma connection is alive on the current event loop.
On Python 3.11, the httpx transport inside Prisma can reference a stale
(closed) event loop when session-scoped async fixtures are evaluated long
after the initial ``server`` fixture connected Prisma. A cheap health-check
followed by a reconnect fixes this without affecting other fixtures.
"""
try:
await prisma.query_raw("SELECT 1")
except Exception:
_logger.info("Prisma connection stale reconnecting")
try:
await db_module.disconnect()
except Exception:
pass
await db_module.connect()
def make_session(user_id: str):
return ChatSession(
@@ -55,19 +31,15 @@ def make_session(user_id: str):
)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_test_data(server):
@pytest.fixture(scope="session")
async def setup_test_data():
"""
Set up test data for run_agent tests:
1. Create a test user
2. Create a test graph (agent input -> agent output)
3. Create a store listing and store listing version
4. Approve the store listing version
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",
@@ -151,8 +123,8 @@ async def setup_test_data(server):
unique_slug = f"test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="Test Agent",
description="A simple test agent",
@@ -161,10 +133,10 @@ async def setup_test_data(server):
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
# 4. Approve the store listing version
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval",
@@ -178,19 +150,15 @@ async def setup_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_llm_test_data(server):
@pytest.fixture(scope="session")
async def setup_llm_test_data():
"""
Set up test data for LLM agent tests:
1. Create a test user
2. Create test OpenAI credentials for the user
3. Create a test graph with input -> LLM block -> output
4. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
key = getenv("OPENAI_API_KEY")
if not key:
return pytest.skip("OPENAI_API_KEY is not set")
@@ -321,8 +289,8 @@ async def setup_llm_test_data(server):
unique_slug = f"llm-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="LLM Test Agent",
description="An agent with LLM capabilities",
@@ -330,9 +298,9 @@ async def setup_llm_test_data(server):
categories=["testing", "ai"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for LLM agent",
@@ -347,18 +315,14 @@ async def setup_llm_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_firecrawl_test_data(server):
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
"""
Set up test data for Firecrawl agent tests (missing credentials scenario):
1. Create a test user (WITHOUT Firecrawl credentials)
2. Create a test graph with input -> Firecrawl block -> output
3. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",
@@ -476,8 +440,8 @@ async def setup_firecrawl_test_data(server):
unique_slug = f"firecrawl-test-agent-{str(uuid.uuid4())[:8]}"
store_submission = await store_db.create_store_submission(
user_id=user.id,
graph_id=created_graph.id,
graph_version=created_graph.version,
agent_id=created_graph.id,
agent_version=created_graph.version,
slug=unique_slug,
name="Firecrawl Test Agent",
description="An agent with Firecrawl integration (no credentials)",
@@ -485,9 +449,9 @@ async def setup_firecrawl_test_data(server):
categories=["testing", "scraping"],
image_urls=["https://example.com/image.jpg"],
)
assert store_submission.listing_version_id is not None
assert store_submission.store_listing_version_id is not None
await store_db.review_store_submission(
store_listing_version_id=store_submission.listing_version_id,
store_listing_version_id=store_submission.store_listing_version_id,
is_approved=True,
external_comments="Approved for testing",
internal_comments="Test approval for Firecrawl agent",

View File

@@ -3,9 +3,11 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(
user_id, input_data
)
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary (filter out empty values)
current_understanding = {

View File

@@ -0,0 +1,31 @@
"""Agent generator package - Creates agents from natural language."""
from .core import (
AgentGeneratorNotConfiguredError,
decompose_goal,
generate_agent,
generate_agent_patch,
get_agent_as_json,
json_to_graph,
save_agent_to_library,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
# Core functions
"decompose_goal",
"generate_agent",
"generate_agent_patch",
"save_agent_to_library",
"get_agent_as_json",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"is_external_service_configured",
"check_external_service_health",
# Error handling
"get_user_message_for_error",
]

View File

@@ -0,0 +1,281 @@
"""Core agent generation functions."""
import logging
import uuid
from typing import Any
from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph
from .service import (
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
is_external_service_configured,
)
logger = logging.getLogger(__name__)
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
pass
def _check_service_configured() -> None:
"""Check if the external Agent Generator service is configured.
Raises:
AgentGeneratorNotConfiguredError: If the service is not configured.
"""
if not is_external_service_configured():
raise AgentGeneratorNotConfiguredError(
"Agent Generator service is not configured. "
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
)
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
return await decompose_goal_external(description, context)
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions)
if result:
# Check if it's an error response - pass through as-is
if isinstance(result, dict) and result.get("type") == "error":
return result
# Ensure required fields for successful agent generation
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
result["version"] = 1
if "is_active" not in result:
result["is_active"] = True
return result
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
Args:
agent_json: Agent JSON with nodes and links
Returns:
Graph ready for saving
"""
nodes = []
for n in agent_json.get("nodes", []):
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=n["block_id"],
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for link_data in agent_json.get("links", []):
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
is_static=link_data.get("is_static", False),
)
links.append(link)
return Graph(
id=agent_json.get("id", str(uuid.uuid4())),
version=agent_json.get("version", 1),
is_active=agent_json.get("is_active", True),
name=agent_json.get("name", "Generated Agent"),
description=agent_json.get("description", ""),
nodes=nodes,
links=links,
)
def _reassign_node_ids(graph: Graph) -> None:
"""Reassign all node and link IDs to new UUIDs.
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
"""Save agent to database and user's library.
Args:
agent_json: Agent JSON dict
user_id: User ID
is_update: Whether this is an update to an existing agent
Returns:
Tuple of (created Graph, LibraryAgent)
"""
from backend.data.graph import get_graph_all_versions
graph = json_to_graph(agent_json)
if is_update:
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
sensitive_action_safe_mode=True,
create_library_agents_for_sub_graphs=False,
)
return created_graph, library_agents[0]
async def get_agent_as_json(
graph_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
graph_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
from backend.data.graph import get_graph
# Try to get the graph (version=None gets the active version)
graph = await get_graph(graph_id, version=None, user_id=user_id)
if not graph:
return None
# Convert to JSON format
nodes = []
for node in graph.nodes:
nodes.append(
{
"id": node.id,
"block_id": node.block_id,
"input_default": node.input_default,
"metadata": node.metadata,
}
)
links = []
for node in graph.nodes:
for link in node.output_links:
links.append(
{
"id": link.id,
"source_id": link.source_id,
"sink_id": link.sink_id,
"source_name": link.source_name,
"sink_name": link.sink_name,
"is_static": link.is_static,
}
)
return {
"id": graph.id,
"name": graph.name,
"description": graph.description,
"version": graph.version,
"is_active": graph.is_active,
"nodes": nodes,
"links": links,
}
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
The external Agent Generator service handles:
- Generating the patch
- Applying the patch
- Fixing and validating the result
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)

View File

@@ -0,0 +1,43 @@
"""Error handling utilities for agent generator."""
def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.
This function maps internal error types to user-friendly messages,
providing a consistent experience across different agent operations.
Args:
error_type: The error type from the external service
(e.g., "llm_parse_error", "timeout", "rate_limit")
operation: Description of what operation failed, used in the default
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
Returns:
User-friendly error message suitable for display to the user
"""
if error_type == "llm_parse_error":
return (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
return (
validation_message
or "The request failed validation. Please try rephrasing."
)
elif error_type == "patch_error":
return "Failed to apply the changes. Please try a different approach."
elif error_type in ("timeout", "llm_timeout"):
return "The request took too long. Please try again."
elif error_type in ("rate_limit", "llm_rate_limit"):
return "The service is currently busy. Please try again in a moment."
else:
return f"Failed to {operation}. Please try again."

View File

@@ -0,0 +1,374 @@
"""External Agent Generator service client.
This module provides a client for communicating with the external Agent Generator
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
will delegate to the external service instead of using the built-in LLM-based implementation.
"""
import logging
from typing import Any
import httpx
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
def _get_settings() -> Settings:
"""Get or create settings singleton."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def is_external_service_configured() -> bool:
"""Check if external Agent Generator service is configured."""
settings = _get_settings()
return bool(settings.config.agentgenerator_host)
def _get_base_url() -> str:
"""Get the base URL for the external service."""
settings = _get_settings()
host = settings.config.agentgenerator_host
port = settings.config.agentgenerator_port
return f"http://{host}:{port}"
def _get_client() -> httpx.AsyncClient:
"""Get or create the HTTP client for the external service."""
global _client
if _client is None:
settings = _get_settings()
_client = httpx.AsyncClient(
base_url=_get_base_url(),
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
)
return _client
async def decompose_goal_external(
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
Returns:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
client = _get_client()
# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
try:
response = await client.post("/api/decompose-description", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
if response_type == "instructions":
return {"type": "instructions", "steps": data.get("steps", [])}
elif response_type == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
elif response_type == "unachievable_goal":
return {
"type": "unachievable_goal",
"reason": data.get("reason"),
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "vague_goal":
return {
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any],
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
Returns:
Agent JSON dict on success, or error dict {"type": "error", ...} on error
"""
client = _get_client()
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
Returns:
Updated agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:
"""Get available blocks from the external service.
Returns:
List of block info dicts or None on error
"""
client = _get_client()
try:
response = await client.get("/api/blocks")
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error("External service returned error getting blocks")
return None
return data.get("blocks", [])
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error getting blocks from external service: {e}")
return None
except httpx.RequestError as e:
logger.error(f"Request error getting blocks from external service: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error getting blocks from external service: {e}")
return None
async def health_check() -> bool:
"""Check if the external service is healthy.
Returns:
True if healthy, False otherwise
"""
if not is_external_service_configured():
return False
client = _get_client()
try:
response = await client.get("/health")
response.raise_for_status()
data = response.json()
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
except Exception as e:
logger.warning(f"External agent generator health check failed: {e}")
return False
async def close_client() -> None:
"""Close the HTTP client."""
global _client
if _client is not None:
await _client.aclose()
_client = None

View File

@@ -5,15 +5,15 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -34,7 +34,6 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -118,11 +117,6 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -152,13 +146,6 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -178,12 +165,10 @@ class AgentOutputTool(BaseTool):
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
lib_db = library_db()
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await lib_db.get_library_agent(library_agent_id, user_id)
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
@@ -197,7 +182,7 @@ class AgentOutputTool(BaseTool):
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
@@ -209,7 +194,7 @@ class AgentOutputTool(BaseTool):
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await lib_db.list_library_agents(
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
@@ -238,20 +223,14 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await exec_db.get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
@@ -260,25 +239,11 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
)
# Get executions with time filters
executions = await exec_db.get_graph_executions(
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=statuses,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -289,7 +254,7 @@ class AgentOutputTool(BaseTool):
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -297,7 +262,7 @@ class AgentOutputTool(BaseTool):
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -345,33 +310,10 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status == ExecutionStatus.REVIEW:
message = (
f"Execution for agent '{agent.name}' is awaiting human review. "
"The user needs to approve it before it can continue."
)
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f" Showing latest of {len(available_executions)} matching executions."
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -438,7 +380,7 @@ class AgentOutputTool(BaseTool):
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db().get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
@@ -450,7 +392,7 @@ class AgentOutputTool(BaseTool):
)
# Find library agent by graph_id
agent = await library_db().get_library_agent_by_graph_id(
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
@@ -486,17 +428,13 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -505,17 +443,4 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -0,0 +1,151 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
from typing import Literal
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
AgentInfo,
AgentsFoundResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
Args:
query: Search query string
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
if not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
if source == "library" and not user_id:
return ErrorResponse(
message="User authentication required to search library",
session_id=session_id,
)
agents: list[AgentInfo] = []
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
id=f"{agent.creator}/{agent.slug}",
name=agent.agent_name,
description=agent.description or "",
source="marketplace",
in_library=False,
creator=agent.creator,
category="general",
rating=agent.rating,
runs=agent.runs,
is_featured=False,
)
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
except DatabaseError as e:
logger.error(f"Error searching {source}: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to search {source}. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
suggestions = (
[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
if source == "marketplace"
else [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
)
no_results_msg = (
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace"
else f"No agents matching '{query}' found in your library."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
)
return AgentsFoundResponse(
message=message,
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
)

View File

@@ -0,0 +1,129 @@
"""Base classes and shared utilities for chat tools."""
import logging
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class BaseTool:
"""Base class for all chat tools."""
@property
def name(self) -> str:
"""Tool name for OpenAI function calling."""
raise NotImplementedError
@property
def description(self) -> str:
"""Tool description for OpenAI."""
raise NotImplementedError
@property
def parameters(self) -> dict[str, Any]:
"""Tool parameters schema for OpenAI."""
raise NotImplementedError
@property
def requires_auth(self) -> bool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(
type="function",
function={
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
)
async def execute(
self,
user_id: str | None,
session: ChatSession,
tool_call_id: str,
**kwargs,
) -> StreamToolOutputAvailable:
"""Execute the tool with authentication check.
Args:
user_id: User ID (may be anonymous like "anon_123")
session_id: Chat session ID
**kwargs: Tool-specific parameters
Returns:
Pydantic response object
"""
if self.requires_auth and not user_id:
logger.error(
f"Attempted tool call for {self.name} but user not authenticated"
)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=NeedLoginResponse(
message=f"Please sign in to use {self.name}",
session_id=session.session_id,
).model_dump_json(),
success=False,
)
try:
result = await self._execute(user_id, session, **kwargs)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=result.model_dump_json(),
)
except Exception as e:
logger.error(f"Error in {self.name}: {e}", exc_info=True)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=ErrorResponse(
message=f"An error occurred while executing {self.name}",
error=str(e),
session_id=session.session_id,
).model_dump_json(),
success=False,
)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Internal execution logic to be implemented by subclasses.
Args:
user_id: User ID (authenticated or anonymous)
session_id: Chat session ID
**kwargs: Tool-specific parameters
Returns:
Pydantic response object
"""
raise NotImplementedError

View File

@@ -0,0 +1,283 @@
"""CreateAgentTool - Creates agents from natural language descriptions."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
generate_agent,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@property
def name(self) -> str:
return "create_agent"
@property
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"description": {
"type": "string",
"description": (
"Natural language description of what the agent should do. "
"Be specific about inputs, outputs, and the workflow steps."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions. "
"Include any preferences or constraints mentioned by the user."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["description"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
Flow:
1. Decompose the description into steps (may return clarifying questions)
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
session_id=session_id,
)
# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(description, context)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information to create this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return ErrorResponse(
message=(
f"This goal cannot be accomplished with the available blocks. "
f"{reason} "
f"Suggestion: {suggested}"
),
error="unachievable_goal",
details={"suggested_goal": suggested, "reason": reason},
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
return ErrorResponse(
message=(
f"The goal is too vague to create a specific workflow. "
f"Suggestion: {suggested}"
),
error="vague_goal",
details={"suggested_goal": suggested},
session_id=session_id,
)
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent generation is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
f"Review it and call create_agent with save=true to save it to your library."
),
agent_json=agent_json,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -0,0 +1,249 @@
"""EditAgentTool - Edits existing agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@property
def name(self) -> str:
return "edit_agent"
@property
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The ID of the agent to edit. "
"Can be a graph ID or library agent ID."
),
},
"changes": {
"type": "string",
"description": (
"Natural language description of what changes to make. "
"Be specific about what to add, remove, or modify."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the changes. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "changes"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
Flow:
1. Fetch the current agent
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
session_id=session_id,
)
if not changes:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
error="agent_not_found",
session_id=session_id,
)
# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(update_request, current_agent)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent editing is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
message=(
"I need some more information about the changes. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
],
session_id=session_id,
)
# Result is the updated agent JSON
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
agent_description = updated_agent.get("description", "")
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "
f"The agent now has {node_count} blocks. "
f"Review it and call edit_agent with save=true to save the changes."
),
agent_json=updated_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
try:
created_graph, library_agent = await save_agent_to_library(
updated_agent, user_id, is_update=True
)
return AgentSavedResponse(
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
return ErrorResponse(
message=f"Failed to save the updated agent: {str(e)}",
error="save_failed",
details={"exception": str(e)},
session_id=session_id,
)

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -0,0 +1,193 @@
import logging
from typing import Any
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
BlockInfoSummary,
BlockInputFieldInfo,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block
logger = logging.getLogger(__name__)
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@property
def name(self) -> str:
return "find_block"
@property
def description(self) -> str:
return (
"Search for available blocks by name or description. "
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, required_inputs, and input_schema."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find blocks by name or description. "
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
Args:
user_id: User ID (required)
session: Chat session
query: Search query
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
try:
# Search for blocks using hybrid search
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=10,
)
if not results:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms",
],
session_id=session_id,
)
# Enrich results with full block information
blocks: list[BlockInfoSummary] = []
for result in results:
block_id = result["content_id"]
block = get_block(block_id)
# Skip disabled blocks
if block and not block.disabled:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
if not blocks:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
],
session_id=session_id,
)
return BlockListResponse(
message=(
f"Found {len(blocks)} block(s) matching '{query}'. "
"To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
),
blocks=blocks,
count=len(blocks),
query=query,
session_id=session_id,
)
except Exception as e:
logger.error(f"Error searching blocks: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search blocks",
error=str(e),
session_id=session_id,
)

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool
@@ -19,13 +19,9 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"When creating agents with sub-agent composition, use this to get "
"the agent's graph_id, graph_version, input_schema, and output_schema "
"needed for AgentExecutorBlock nodes. "
"Omit the query to list all agents."
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
)
@property
@@ -35,13 +31,10 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
"description": "Search query to find agents by name or description.",
},
},
"required": [],
"required": ["query"],
}
@property
@@ -52,7 +45,7 @@ class FindLibraryAgentTool(BaseTool):
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=(kwargs.get("query") or "").strip(),
query=kwargs.get("query", "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -4,10 +4,13 @@ import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,382 @@
"""Pydantic models for tool responses."""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from backend.data.model import CredentialsMetaInput
class ResponseType(str, Enum):
"""Types of tool responses."""
AGENTS_FOUND = "agents_found"
AGENT_DETAILS = "agent_details"
SETUP_REQUIREMENTS = "setup_requirements"
EXECUTION_STARTED = "execution_started"
NEED_LOGIN = "need_login"
ERROR = "error"
NO_RESULTS = "no_results"
AGENT_OUTPUT = "agent_output"
UNDERSTANDING_UPDATED = "understanding_updated"
AGENT_PREVIEW = "agent_preview"
AGENT_SAVED = "agent_saved"
CLARIFICATION_NEEDED = "clarification_needed"
BLOCK_LIST = "block_list"
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Base response model
class ToolResponseBase(BaseModel):
"""Base model for all tool responses."""
type: ResponseType
message: str
session_id: str | None = None
# Agent discovery models
class AgentInfo(BaseModel):
"""Information about an agent."""
id: str
name: str
description: str
source: str = Field(description="marketplace or library")
in_library: bool = False
creator: str | None = None
category: str | None = None
rating: float | None = None
runs: int | None = None
is_featured: bool | None = None
status: str | None = None
can_access_graph: bool | None = None
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
class AgentsFoundResponse(ToolResponseBase):
"""Response for find_agent tool."""
type: ResponseType = ResponseType.AGENTS_FOUND
title: str = "Available Agents"
agents: list[AgentInfo]
count: int
name: str = "agents_found"
class NoResultsResponse(ToolResponseBase):
"""Response when no agents found."""
type: ResponseType = ResponseType.NO_RESULTS
suggestions: list[str] = []
name: str = "no_results"
# Agent details models
class InputField(BaseModel):
"""Input field specification."""
name: str
type: str = "string"
description: str = ""
required: bool = False
default: Any | None = None
options: list[Any] | None = None
format: str | None = None
class ExecutionOptions(BaseModel):
"""Available execution options for an agent."""
manual: bool = True
scheduled: bool = True
webhook: bool = False
class AgentDetails(BaseModel):
"""Detailed agent information."""
id: str
name: str
description: str
in_library: bool = False
inputs: dict[str, Any] = {}
credentials: list[CredentialsMetaInput] = []
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
trigger_info: dict[str, Any] | None = None
class AgentDetailsResponse(ToolResponseBase):
"""Response for get_details action."""
type: ResponseType = ResponseType.AGENT_DETAILS
agent: AgentDetails
user_authenticated: bool = False
graph_id: str | None = None
graph_version: int | None = None
# Setup info models
class UserReadiness(BaseModel):
"""User readiness status."""
has_all_credentials: bool = False
missing_credentials: dict[str, Any] = {}
ready_to_run: bool = False
class SetupInfo(BaseModel):
"""Complete setup information."""
agent_id: str
agent_name: str
requirements: dict[str, list[Any]] = Field(
default_factory=lambda: {
"credentials": [],
"inputs": [],
"execution_modes": [],
},
)
user_readiness: UserReadiness = Field(default_factory=UserReadiness)
class SetupRequirementsResponse(ToolResponseBase):
"""Response for validate action."""
type: ResponseType = ResponseType.SETUP_REQUIREMENTS
setup_info: SetupInfo
graph_id: str | None = None
graph_version: int | None = None
# Execution models
class ExecutionStartedResponse(ToolResponseBase):
"""Response for run/schedule actions."""
type: ResponseType = ResponseType.EXECUTION_STARTED
execution_id: str
graph_id: str
graph_name: str
library_agent_id: str | None = None
library_agent_link: str | None = None
status: str = "QUEUED"
# Auth/error models
class NeedLoginResponse(ToolResponseBase):
"""Response when login is needed."""
type: ResponseType = ResponseType.NEED_LOGIN
agent_info: dict[str, Any] | None = None
class ErrorResponse(ToolResponseBase):
"""Response for errors."""
type: ResponseType = ResponseType.ERROR
error: str | None = None
details: dict[str, Any] | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
execution_id: str
status: str
started_at: datetime | None = None
ended_at: datetime | None = None
outputs: dict[str, list[Any]]
inputs_summary: dict[str, Any] | None = None
class AgentOutputResponse(ToolResponseBase):
"""Response for agent_output tool."""
type: ResponseType = ResponseType.AGENT_OUTPUT
agent_name: str
agent_id: str
library_agent_id: str | None = None
library_agent_link: str | None = None
execution: ExecutionOutputInfo | None = None
available_executions: list[dict[str, Any]] | None = None
total_executions: int = 0
# Business understanding models
class UnderstandingUpdatedResponse(ToolResponseBase):
"""Response for add_understanding tool."""
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
updated_fields: list[str] = Field(default_factory=list)
current_understanding: dict[str, Any] = Field(default_factory=dict)
# Agent generation models
class ClarifyingQuestion(BaseModel):
"""A question that needs user clarification."""
question: str
keyword: str
example: str | None = None
class AgentPreviewResponse(ToolResponseBase):
"""Response for previewing a generated agent before saving."""
type: ResponseType = ResponseType.AGENT_PREVIEW
agent_json: dict[str, Any]
agent_name: str
description: str
node_count: int
link_count: int = 0
class AgentSavedResponse(ToolResponseBase):
"""Response when an agent is saved to the library."""
type: ResponseType = ResponseType.AGENT_SAVED
agent_id: str
agent_name: str
library_agent_id: str
library_agent_link: str
agent_page_link: str # Link to the agent builder/editor page
class ClarificationNeededResponse(ToolResponseBase):
"""Response when the LLM needs more information from the user."""
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
questions: list[ClarifyingQuestion] = Field(default_factory=list)
# Documentation search models
class DocSearchResult(BaseModel):
"""A single documentation search result."""
title: str
path: str
section: str
snippet: str # Short excerpt for UI display
score: float
doc_url: str | None = None
class DocSearchResultsResponse(ToolResponseBase):
"""Response for search_docs tool."""
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
results: list[DocSearchResult]
count: int
query: str
class DocPageResponse(ToolResponseBase):
"""Response for get_doc_page tool."""
type: ResponseType = ResponseType.DOC_PAGE
title: str
path: str
content: str # Full document content
doc_url: str | None = None
# Block models
class BlockInputFieldInfo(BaseModel):
"""Information about a block input field."""
name: str
type: str
description: str = ""
required: bool = False
default: Any | None = None
class BlockInfoSummary(BaseModel):
"""Summary of a block for search results."""
id: str
name: str
description: str
categories: list[str]
input_schema: dict[str, Any]
output_schema: dict[str, Any]
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of required input fields for this block",
)
class BlockListResponse(ToolResponseBase):
"""Response for find_block tool."""
type: ResponseType = ResponseType.BLOCK_LIST
blocks: list[BlockInfoSummary]
count: int
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the required fields from input_schema."
)
class BlockOutputResponse(ToolResponseBase):
"""Response for run_block tool."""
type: ResponseType = ResponseType.BLOCK_OUTPUT
block_id: str
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str

View File

@@ -5,13 +5,16 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.execution import ExecutionStatus
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -21,17 +24,12 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .execution_utils import get_execution_outputs, wait_for_execution
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
AgentOutputResponse,
ErrorResponse,
ExecutionOptions,
ExecutionOutputInfo,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -70,7 +68,6 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -152,14 +149,6 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
},
},
"required": [],
}
@@ -209,7 +198,7 @@ class RunAgentTool(BaseTool):
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db().get_library_agent(
library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
@@ -218,7 +207,9 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Get the graph from the library agent
graph = await graph_db().get_graph(
from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
@@ -269,7 +260,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -282,22 +273,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide
@@ -354,7 +329,6 @@ class RunAgentTool(BaseTool):
graph=graph,
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
)
except NotFoundError as e:
@@ -378,6 +352,22 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -391,7 +381,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema)
inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]
@@ -438,9 +428,8 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
"""Execute an agent immediately."""
session_id = session.session_id
# Check rate limits
@@ -477,91 +466,6 @@ class RunAgentTool(BaseTool):
)
library_agent_link = f"/library/agents/{library_agent.id}"
# If wait_for_result is requested, wait for execution to complete
if wait_for_result > 0:
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
completed = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
if completed and completed.status == ExecutionStatus.COMPLETED:
outputs = get_execution_outputs(completed)
return AgentOutputResponse(
message=(
f"Agent '{library_agent.name}' completed successfully. "
f"View at {library_agent_link}."
),
session_id=session_id,
agent_name=library_agent.name,
agent_id=library_agent.graph_id,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
execution=ExecutionOutputInfo(
execution_id=execution.id,
status=completed.status.value,
started_at=completed.started_at,
ended_at=completed.ended_at,
outputs=outputs or {},
),
)
elif completed and completed.status == ExecutionStatus.FAILED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution failed. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.TERMINATED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution was terminated. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.REVIEW:
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"Check at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=ExecutionStatus.REVIEW.value,
)
else:
status = completed.status.value if completed else "unknown"
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is still {status} after "
f"{wait_for_result}s. Check results later at "
f"{library_agent_link}. "
f"Use view_agent_output with wait_if_running to check again."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=status,
)
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' execution started successfully. "
@@ -616,7 +520,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id)
# Get user timezone
user = await user_db().get_user_by_id(user_id)
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
# Create schedule

View File

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -5,35 +5,24 @@ import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.blocks import BlockType, get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -51,8 +40,8 @@ class RunBlockTool(BaseTool):
"Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"On first attempt (without input_data), returns detailed schema showing "
"required inputs and outputs. Then call again with proper input_data to execute."
"Use the 'id' from find_block results and provide input_data "
"matching the block's required_inputs."
)
@property
@@ -67,29 +56,80 @@ class RunBlockTool(BaseTool):
"NEVER guess this - always get it from find_block first."
),
},
"block_name": {
"type": "string",
"description": (
"The block's human-readable name from find_block results. "
"Used for display purposes in the UI."
),
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. "
"First call with empty {} to see the block's schema, "
"then call again with proper values to execute."
"Input values for the block. Use the 'required_inputs' field "
"from find_block to see what fields are needed."
),
},
},
"required": ["block_id", "block_name", "input_data"],
"required": ["block_id", "input_data"],
}
@property
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -144,61 +184,13 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
# Provide actionable guidance for blocks with dedicated tools
if block.block_type == BlockType.MCP_TOOL:
hint = (
" Use the `run_mcp_tool` tool instead — it handles "
"MCP server discovery, authentication, and execution."
)
elif block.block_type == BlockType.AGENT:
hint = " Use the `run_agent` tool instead."
else:
hint = " This block is designed for use within graphs only."
return ErrorResponse(
message=f"Block '{block.name}' cannot be run directly.{hint}",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:
input_schema: dict[str, Any] = block.input_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid input schema",
error=str(e),
session_id=session_id,
)
try:
output_schema: dict[str, Any] = block.output_schema.jsonschema()
except Exception as e:
logger.warning(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
return ErrorResponse(
message=f"Block '{block.name}' has an invalid output schema",
error=str(e),
session_id=session_id,
)
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
)
if missing_credentials:
# Return setup requirements response with missing credentials
@@ -232,56 +224,9 @@ class RunBlockTool(BaseTool):
graph_version=None,
)
# Check if this is a first attempt (required inputs missing)
# Return block details so user can see what inputs are needed
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
required_keys = set(input_schema.get("required", []))
required_non_credential_keys = required_keys - credentials_fields
provided_input_keys = set(input_data.keys()) - credentials_fields
# Check for unknown input fields
valid_fields = (
set(input_schema.get("properties", {}).keys()) - credentials_fields
)
unrecognized_fields = provided_input_keys - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Block was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=input_schema,
)
# Show details when not all required non-credential inputs are provided
if not (required_non_credential_keys <= provided_input_keys):
# Get credentials info for the response
credentials_meta = []
for field_name, cred_meta in matched_credentials.items():
credentials_meta.append(cred_meta)
return BlockDetailsResponse(
message=(
f"Block '{block.name}' details. "
"Provide input_data matching the inputs schema to execute the block."
),
session_id=session_id,
block=BlockDetails(
id=block_id,
name=block.name,
description=block.description or "",
inputs=input_schema,
outputs=output_schema,
credentials=credentials_meta,
),
user_authenticated=True,
)
try:
# Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
@@ -373,75 +318,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
resolved: dict[str, CredentialsFieldInfo] = {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
return inputs_list

View File

@@ -5,17 +5,16 @@ from typing import Any
from prisma.enums import ContentType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__)
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await search().unified_hybrid_search(
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,

View File

@@ -3,18 +3,13 @@
import logging
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.data.db_accessors import library_db, store_db
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -38,15 +33,20 @@ async def fetch_graph_from_store_slug(
Raises:
DatabaseError: If there's a database error during lookup.
"""
sdb = store_db()
try:
store_agent = await sdb.get_store_agent_details(username, agent_name)
store_agent = await store_db.get_store_agent_details(username, agent_name)
except NotFoundError:
return None, None
# Get the graph from store listing version
graph = await sdb.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
graph_meta = await store_db.get_available_graph(
store_agent.store_listing_version_id
)
graph = await graph_db.get_graph(
graph_id=graph_meta.id,
version=graph_meta.version,
user_id=None, # Public access
include_subgraphs=True,
)
return graph, store_agent
@@ -123,7 +123,7 @@ def build_missing_credentials_from_graph(
return {
field_key: _serialize_missing_credential(field_key, field_info)
for field_key, (field_info, _, _) in aggregated_fields.items()
for field_key, (field_info, _node_fields) in aggregated_fields.items()
if field_key not in matched_keys
}
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
Returns:
LibraryAgent instance
"""
existing = await library_db().get_library_agent_by_graph_id(
existing = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
)
if existing:
return existing
library_agents = await library_db().create_library_agent(
library_agents = await library_db.create_library_agent(
graph=graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,
@@ -225,99 +225,6 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -357,28 +264,15 @@ async def match_user_credentials_to_graph(
# provider is in the set of acceptable providers.
for credential_field_name, (
credential_requirements,
_,
_,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL
# Find first matching credential by provider and type
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and (
cred.type != "oauth2"
or _credential_has_required_scopes(cred, credential_requirements)
)
and (
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
),
None,
)
@@ -402,17 +296,10 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
)
logger.info(
@@ -422,49 +309,6 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes(
credential: OAuth2Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an OAuth2 credential has all the scopes required by the input."""
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
return set(credential.scopes).issuperset(requirements.required_scopes)
def _credential_is_for_host(
credential: HostScopedCredentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if a host-scoped credential matches the host required by the input."""
# We need to know the host to match host-scoped credentials to.
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
# to discriminator_values. No discriminator_values -> no host to match against.
if not requirements.discriminator_values:
return True
# Check that credential host matches required host.
# Host-scoped credential inputs are grouped by host, so any item from the set works.
return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -0,0 +1,620 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -22,7 +22,6 @@ from backend.data.human_review import (
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.data.workspace import get_or_create_workspace
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -322,13 +321,10 @@ async def process_review_action(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
workspace = await get_or_create_workspace(user_id)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
workspace_id=workspace.id,
)
await add_graph_execution(

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
default=None, description="Host pattern for host-scoped credentials"
)
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None
if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value):
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked)

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ import prisma.enums
import prisma.models
import pytest
import backend.api.features.store.exceptions
from backend.data.db import connect
from backend.data.includes import library_agent_include
@@ -143,7 +144,6 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
@@ -178,6 +178,7 @@ async def test_add_agent_to_library(mocker):
"agentGraphVersion": 1,
}
},
include={"AgentGraph": True},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
)
# Call function and verify exception
with pytest.raises(db.NotFoundError):
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly

View File

@@ -1,10 +0,0 @@
class FolderValidationError(Exception):
"""Raised when folder operations fail validation."""
pass
class FolderAlreadyExistsError(FolderValidationError):
"""Raised when a folder with the same name already exists in the location."""
pass

View File

@@ -6,13 +6,9 @@ import prisma.enums
import prisma.models
import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import (
CredentialsMetaInput,
GraphInput,
is_credentials_field_name,
)
from backend.util.json import loads as json_loads
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.models import Pagination
if TYPE_CHECKING:
@@ -20,99 +16,10 @@ if TYPE_CHECKING:
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
# === Folder Models ===
class LibraryFolder(pydantic.BaseModel):
"""Represents a folder for organizing library agents."""
id: str
user_id: str
name: str
icon: str | None = None
color: str | None = None
parent_id: str | None = None
created_at: datetime.datetime
updated_at: datetime.datetime
agent_count: int = 0 # Direct agents in folder
subfolder_count: int = 0 # Direct child folders
@staticmethod
def from_db(
folder: prisma.models.LibraryFolder,
agent_count: int = 0,
subfolder_count: int = 0,
) -> "LibraryFolder":
"""Factory method that constructs a LibraryFolder from a Prisma model."""
return LibraryFolder(
id=folder.id,
user_id=folder.userId,
name=folder.name,
icon=folder.icon,
color=folder.color,
parent_id=folder.parentId,
created_at=folder.createdAt,
updated_at=folder.updatedAt,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
class LibraryFolderTree(LibraryFolder):
"""Folder with nested children for tree view."""
children: list["LibraryFolderTree"] = []
class FolderCreateRequest(pydantic.BaseModel):
"""Request model for creating a folder."""
name: str = pydantic.Field(..., min_length=1, max_length=100)
icon: str | None = None
color: str | None = pydantic.Field(
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
)
parent_id: str | None = None
class FolderUpdateRequest(pydantic.BaseModel):
"""Request model for updating a folder."""
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
icon: str | None = None
color: str | None = None
class FolderMoveRequest(pydantic.BaseModel):
"""Request model for moving a folder to a new parent."""
target_parent_id: str | None = None # None = move to root
class BulkMoveAgentsRequest(pydantic.BaseModel):
"""Request model for moving multiple agents to a folder."""
agent_ids: list[str]
folder_id: str | None = None # None = move to root
class FolderListResponse(pydantic.BaseModel):
"""Response schema for a list of folders."""
folders: list[LibraryFolder]
pagination: Pagination
class FolderTreeResponse(pydantic.BaseModel):
"""Response schema for folder tree structure."""
tree: list[LibraryFolderTree]
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
class MarketplaceListingCreator(pydantic.BaseModel):
@@ -132,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator
class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.
Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""
status: str
correctness_score: float | None = None
activity_summary: str | None = None
def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -165,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -181,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None
input_schema: dict[str, Any]
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -198,22 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
folder_name: str | None = None # Denormalized for display
# Whether the user can access the underlying graph
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
# Whether the agent is marked as favorite by the user
is_favorite: bool
# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None
# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None
@staticmethod
@@ -237,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -249,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -257,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100
correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)
recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)
# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId
# Hard-coded to True until a method to check is implemented
is_latest_version = True
# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -344,17 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
settings=GraphSettings.model_validate(agent.settings),
marketplace_listing=marketplace_listing_data,
)
@@ -380,15 +220,18 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
@@ -420,7 +263,7 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_id: str
graph_version: int
inputs: GraphInput
inputs: BlockInput
credentials: dict[str, CredentialsMetaInput]
name: str
@@ -449,7 +292,7 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
Request model used when updating a preset for a library agent.
"""
inputs: Optional[GraphInput] = None
inputs: Optional[BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None
@@ -492,7 +335,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"Webhook must be included in AgentPreset query when webhookId is set"
)
input_data: GraphInput = {}
input_data: BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets:
@@ -564,7 +407,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
settings: Optional[GraphSettings] = pydantic.Field(
default=None, description="User-specific settings for this library agent"
)
folder_id: Optional[str] = pydantic.Field(
default=None,
description="Folder ID to move agent to (None to move to root)",
)

View File

@@ -1,11 +1,9 @@
import fastapi
from .agents import router as agents_router
from .folders import router as folders_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(folders_router)
router.include_router(agents_router)

View File

@@ -41,14 +41,6 @@ async def list_library_agents(
ge=1,
description="Number of agents per page (must be >= 1)",
),
folder_id: Optional[str] = Query(
None,
description="Filter by folder ID",
),
include_root_only: bool = Query(
False,
description="Only return agents without a folder (root-level agents)",
),
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
@@ -59,8 +51,6 @@ async def list_library_agents(
sort_by=sort_by,
page=page,
page_size=page_size,
folder_id=folder_id,
include_root_only=include_root_only,
)
@@ -178,7 +168,6 @@ async def update_library_agent(
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
folder_id=payload.folder_id,
)

View File

@@ -1,287 +0,0 @@
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Query, Security, status
from fastapi.responses import Response
from .. import db as library_db
from .. import model as library_model
router = APIRouter(
prefix="/folders",
tags=["library", "folders", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@router.get(
"",
summary="List Library Folders",
response_model=library_model.FolderListResponse,
responses={
200: {"description": "List of folders"},
500: {"description": "Server error"},
},
)
async def list_folders(
user_id: str = Security(autogpt_auth_lib.get_user_id),
parent_id: Optional[str] = Query(
None,
description="Filter by parent folder ID. If not provided, returns root-level folders.",
),
include_relations: bool = Query(
True,
description="Include agent and subfolder relations (for counts)",
),
) -> library_model.FolderListResponse:
"""
List folders for the authenticated user.
Args:
user_id: ID of the authenticated user.
parent_id: Optional parent folder ID to filter by.
include_relations: Whether to include agent and subfolder relations for counts.
Returns:
A FolderListResponse containing folders.
"""
folders = await library_db.list_folders(
user_id=user_id,
parent_id=parent_id,
include_relations=include_relations,
)
return library_model.FolderListResponse(
folders=folders,
pagination=library_model.Pagination(
total_items=len(folders),
total_pages=1,
current_page=1,
page_size=len(folders),
),
)
@router.get(
"/tree",
summary="Get Folder Tree",
response_model=library_model.FolderTreeResponse,
responses={
200: {"description": "Folder tree structure"},
500: {"description": "Server error"},
},
)
async def get_folder_tree(
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.FolderTreeResponse:
"""
Get the full folder tree for the authenticated user.
Args:
user_id: ID of the authenticated user.
Returns:
A FolderTreeResponse containing the nested folder structure.
"""
tree = await library_db.get_folder_tree(user_id=user_id)
return library_model.FolderTreeResponse(tree=tree)
@router.get(
"/{folder_id}",
summary="Get Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder details"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def get_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Get a specific folder.
Args:
folder_id: ID of the folder to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryFolder.
"""
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
@router.post(
"",
summary="Create Folder",
status_code=status.HTTP_201_CREATED,
response_model=library_model.LibraryFolder,
responses={
201: {"description": "Folder created successfully"},
400: {"description": "Validation error"},
404: {"description": "Parent folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def create_folder(
payload: library_model.FolderCreateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Create a new folder.
Args:
payload: The folder creation request.
user_id: ID of the authenticated user.
Returns:
The created LibraryFolder.
"""
return await library_db.create_folder(
user_id=user_id,
name=payload.name,
parent_id=payload.parent_id,
icon=payload.icon,
color=payload.color,
)
@router.patch(
"/{folder_id}",
summary="Update Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder updated successfully"},
400: {"description": "Validation error"},
404: {"description": "Folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def update_folder(
folder_id: str,
payload: library_model.FolderUpdateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Update a folder's properties.
Args:
folder_id: ID of the folder to update.
payload: The folder update request.
user_id: ID of the authenticated user.
Returns:
The updated LibraryFolder.
"""
return await library_db.update_folder(
folder_id=folder_id,
user_id=user_id,
name=payload.name,
icon=payload.icon,
color=payload.color,
)
@router.post(
"/{folder_id}/move",
summary="Move Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder moved successfully"},
400: {"description": "Validation error (circular reference)"},
404: {"description": "Folder or target parent not found"},
409: {"description": "Folder name conflict in target location"},
500: {"description": "Server error"},
},
)
async def move_folder(
folder_id: str,
payload: library_model.FolderMoveRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Move a folder to a new parent.
Args:
folder_id: ID of the folder to move.
payload: The move request with target parent.
user_id: ID of the authenticated user.
Returns:
The moved LibraryFolder.
"""
return await library_db.move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=payload.target_parent_id,
)
@router.delete(
"/{folder_id}",
summary="Delete Folder",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Folder deleted successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def delete_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> Response:
"""
Soft-delete a folder and all its contents.
Args:
folder_id: ID of the folder to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
"""
await library_db.delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
# === Bulk Agent Operations ===
@router.post(
"/agents/bulk-move",
summary="Bulk Move Agents",
response_model=list[library_model.LibraryAgent],
responses={
200: {"description": "Agents moved successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def bulk_move_agents(
payload: library_model.BulkMoveAgentsRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> list[library_model.LibraryAgent]:
"""
Move multiple agents to a folder.
Args:
payload: The bulk move request with agent IDs and target folder.
user_id: ID of the authenticated user.
Returns:
The updated LibraryAgents.
"""
return await library_db.bulk_move_agents_to_folder(
agent_ids=payload.agent_ids,
folder_id=payload.folder_id,
user_id=user_id,
)

View File

@@ -115,8 +115,6 @@ async def test_get_library_agents_success(
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
folder_id=None,
include_root_only=False,
)

View File

@@ -1,511 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field, SecretStr
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
server_host,
)
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
best_cred = await auto_lookup_mcp_credential(
user_id, normalize_mcp_url(request.server_url)
)
if best_cred:
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or server_host(request.server_url)
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize the URL so that credentials stored here are matched consistently
# by auto_lookup_mcp_credential (which also uses normalized URLs).
server_url = normalize_mcp_url(request.server_url)
client = MCPClient(server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", server_url)
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid authorization server URL in metadata: {e}",
)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = server_host(meta["server_url"])
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
"Removed old MCP credential %s for %s",
old.id,
server_host(meta["server_url"]),
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Bearer Token ======================== #
class MCPStoreTokenRequest(BaseModel):
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
server_url: str = Field(
description="MCP server URL the token authenticates against"
)
token: SecretStr = Field(
min_length=1, description="Bearer token / API key for the MCP server"
)
@router.post(
"/token",
summary="Store a bearer token for an MCP server",
)
async def mcp_store_token(
request: MCPStoreTokenRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Store a manually provided bearer token as an MCP credential.
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
``run_mcp_tool`` calls will automatically pick up the token via
``_auto_lookup_credential``.
"""
token = request.token.get_secret_value().strip()
if not token:
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize URL so trailing-slash variants match existing credentials.
server_url = normalize_mcp_url(request.server_url)
hostname = server_host(server_url)
# Collect IDs of old credentials to clean up after successful create.
old_cred_ids: list[str] = []
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
old_cred_ids = [
old.id
for old in old_creds
if isinstance(old, OAuth2Credentials)
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
== server_url
]
except Exception:
logger.debug("Could not query old MCP token credentials", exc_info=True)
credentials = OAuth2Credentials(
provider=ProviderName.MCP.value,
title=f"MCP: {hostname}",
access_token=SecretStr(token),
scopes=[],
metadata={"mcp_server_url": server_url},
)
await creds_manager.create(user_id, credentials)
# Only delete old credentials after the new one is safely stored.
for old_id in old_cred_ids:
try:
await creds_manager.store.delete_creds_by_id(user_id, old_id)
except Exception:
logger.debug("Could not clean up old MCP token credential", exc_info=True)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=hostname,
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(
"Dynamic client registration failed for %s: %s", server_host(server_url), e
)
return None

View File

@@ -1,572 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from pydantic import SecretStr
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.data.model import OAuth2Credentials
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=stored_cred,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
):
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()
class TestStoreToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_success(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
mock_cm.create = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "my-api-key-123",
},
)
assert response.status_code == 200
data = response.json()
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
assert data["host"] == "mcp.example.com"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_blank_rejected(self, client):
"""Blank token string (after stripping) should return 422."""
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": " ",
},
)
# Pydantic min_length=1 catches the whitespace-only token
assert response.status_code == 422
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_replaces_old_credential(self, client):
old_cred = OAuth2Credentials(
provider="mcp",
title="MCP: mcp.example.com",
access_token=SecretStr("old-token"),
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
mock_cm.create = AsyncMock()
mock_cm.store.delete_creds_by_id = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "new-token",
},
)
assert response.status_code == 200
mock_cm.store.delete_creds_by_id.assert_called_once_with(
"test-user-id", old_cred.id
)
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/discover-tools",
json={"server_url": "http://localhost/mcp"},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
response = await client.post(
"/oauth/login",
json={"server_url": "http://10.0.0.1/mcp"},
)
assert response.status_code == 400
assert "blocked private ip" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/token",
json={
"server_url": "http://127.0.0.1/mcp",
"token": "some-token",
},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()

View File

@@ -5,8 +5,8 @@ from typing import Optional
import aiohttp
from fastapi import HTTPException
from backend.blocks import get_block
from backend.data import graph as graph_db
from backend.data.block import get_block
from backend.util.settings import Settings
from .models import ApiResponse, ChatRequest, GraphData

View File

@@ -1,3 +1,5 @@
from typing import Literal
from backend.util.cache import cached
from . import db as store_db
@@ -21,7 +23,7 @@ def clear_all_caches():
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: store_db.StoreAgentsSortOptions | None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
search_query: str | None,
category: str | None,
page: int,
@@ -55,7 +57,7 @@ async def _get_cached_agent_details(
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: store_db.StoreCreatorsSortOptions | None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
page: int,
page_size: int,
):
@@ -73,4 +75,4 @@ async def _get_cached_store_creators(
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await store_db.get_store_creator(username=username.lower())
return await store_db.get_store_creator_details(username=username.lower())

View File

@@ -9,26 +9,15 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_args, get_origin
from typing import Any
from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -163,7 +152,7 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
from backend.blocks import get_blocks
from backend.data.block import get_blocks
# Get all available blocks
all_blocks = get_blocks()
@@ -199,51 +188,45 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if block_instance.name:
if hasattr(block_instance, "name") and block_instance.name:
parts.append(block_instance.name)
if block_instance.description:
if (
hasattr(block_instance, "description")
and block_instance.description
):
parts.append(block_instance.description)
if block_instance.categories:
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block_input_fields.items()
if field_info.description
]
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
for info in credentials_info.values()
for provider in info.provider
]
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block_instance.input_schema.model_fields.values()
[cat.value for cat in categories] if categories else []
)
items.append(
@@ -252,11 +235,8 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": block_instance.name,
"name": getattr(block_instance, "name", ""),
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None, # Blocks are public
)
@@ -269,7 +249,7 @@ class BlockHandler(ContentHandler):
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
from backend.blocks import get_blocks
from backend.data.block import get_blocks
all_blocks = get_blocks()

View File

@@ -82,10 +82,9 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
@@ -94,7 +93,7 @@ async def test_block_handler_get_missing_items(mocker):
mock_existing = []
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -136,7 +135,7 @@ async def test_block_handler_get_stats(mocker):
mock_embedded = [{"count": 2}]
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -310,25 +309,25 @@ async def test_content_handlers_registry():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
handler = BlockHandler()
# Mock block with empty values (all attributes exist but are falsy)
# Mock block with minimal attributes
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
@@ -353,8 +352,6 @@ async def test_block_handler_skips_failed_blocks():
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()
@@ -363,7 +360,7 @@ async def test_block_handler_skips_failed_blocks():
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.blocks.get_blocks",
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(

Some files were not shown because too many files have changed in this diff Show More