mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
6 Commits
dependabot
...
chore/reac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1090f90d95 | ||
|
|
a7c9a3c5ae | ||
|
|
ec06c1278a | ||
|
|
e2525cb8a8 | ||
|
|
02a3a163e7 | ||
|
|
d9d24dcfe6 |
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: backend-check
|
||||
description: Run the full backend formatting, linting, and test suite. Ensures code quality before commits and PRs. TRIGGER when backend Python code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Backend Check
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Format**: `poetry run format` — runs formatting AND linting. NEVER run ruff/black/isort individually
|
||||
2. **Fix** any remaining errors manually, re-run until clean
|
||||
3. **Test**: `poetry run test` (runs DB setup + pytest). For specific files: `poetry run pytest -s -vvv <test_files>`
|
||||
4. **Snapshots** (if needed): `poetry run pytest path/to/test.py --snapshot-update` — review with `git diff`
|
||||
@@ -1,35 +0,0 @@
|
||||
---
|
||||
name: code-style
|
||||
description: Python code style preferences for the AutoGPT backend. Apply when writing or reviewing Python code. TRIGGER when writing new Python code, reviewing PRs, or refactoring backend code.
|
||||
user-invocable: false
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Code Style
|
||||
|
||||
## Imports
|
||||
|
||||
- **Top-level only** — no local/inner imports. Move all imports to the top of the file.
|
||||
|
||||
## Typing
|
||||
|
||||
- **No duck typing** — avoid `hasattr`, `getattr`, `isinstance` for type dispatch. Use proper typed interfaces, unions, or protocols.
|
||||
- **Pydantic models** over dataclass, namedtuple, or raw dict for structured data.
|
||||
- **No linter suppressors** — avoid `# type: ignore`, `# noqa`, `# pyright: ignore` etc. 99% of the time the right fix is fixing the type/code, not silencing the tool.
|
||||
|
||||
## Code Structure
|
||||
|
||||
- **List comprehensions** over manual loop-and-append.
|
||||
- **Early return** — guard clauses first, avoid deep nesting.
|
||||
- **Flatten inline** — prefer short, concise expressions. Reduce `if/else` chains with direct returns or ternaries when readable.
|
||||
- **Modular functions** — break complex logic into small, focused functions rather than long blocks with nested conditionals.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
Before finishing, always ask:
|
||||
- Can any function be split into smaller pieces?
|
||||
- Is there unnecessary nesting that an early return would eliminate?
|
||||
- Can any loop be a comprehension?
|
||||
- Is there a simpler way to express this logic?
|
||||
@@ -1,16 +0,0 @@
|
||||
---
|
||||
name: frontend-check
|
||||
description: Run the full frontend formatting, linting, and type checking suite. Ensures code quality before commits and PRs. TRIGGER when frontend TypeScript/React code has been modified and needs validation.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Frontend Check
|
||||
|
||||
## Steps (in order)
|
||||
|
||||
1. **Format**: `pnpm format` — NEVER run individual formatters
|
||||
2. **Lint**: `pnpm lint` — fix errors, re-run until clean
|
||||
3. **Types**: `pnpm types` — if it keeps failing after multiple attempts, stop and ask the user
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
name: new-block
|
||||
description: Create a new backend block following the Block SDK Guide. Guides through provider configuration, schema definition, authentication, and testing. TRIGGER when user asks to create a new block, add a new integration, or build a new node for the graph editor.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# New Block Creation
|
||||
|
||||
Read `docs/platform/block-sdk-guide.md` first for the full guide.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Provider config** (if external service): create `_config.py` with `ProviderBuilder`
|
||||
2. **Block file** in `backend/blocks/` (from `autogpt_platform/backend/`):
|
||||
- Generate a UUID once with `uuid.uuid4()`, then **hard-code that string** as `id` (IDs must be stable across imports)
|
||||
- `Input(BlockSchema)` and `Output(BlockSchema)` classes
|
||||
- `async def run` that `yield`s output fields
|
||||
3. **Files**: use `store_media_file()` with `"for_block_output"` for outputs
|
||||
4. **Test**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[MyBlock]' -xvs`
|
||||
5. **Format**: `poetry run format`
|
||||
|
||||
## Rules
|
||||
|
||||
- Analyze interfaces: do inputs/outputs connect well with other blocks in a graph?
|
||||
- Use top-level imports, avoid duck typing
|
||||
- Always use `for_block_output` for block outputs
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
name: openapi-regen
|
||||
description: Regenerate the OpenAPI spec and frontend API client. Starts the backend REST server, fetches the spec, and regenerates the typed frontend hooks. TRIGGER when API routes change, new endpoints are added, or frontend API types are stale.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# OpenAPI Spec Regeneration
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Run end-to-end** in a single shell block (so `REST_PID` persists):
|
||||
```bash
|
||||
cd autogpt_platform/backend && poetry run rest &
|
||||
REST_PID=$!
|
||||
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && kill $REST_PID && exit 1; done
|
||||
cd ../frontend && pnpm generate:api:force
|
||||
kill $REST_PID
|
||||
pnpm types && pnpm lint && pnpm format
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
- Always use `pnpm generate:api:force` (not `pnpm generate:api`)
|
||||
- Don't manually edit files in `src/app/api/__generated__/`
|
||||
- Generated hooks follow: `use{Method}{Version}{OperationName}`
|
||||
@@ -1,31 +0,0 @@
|
||||
---
|
||||
name: pr-create
|
||||
description: Create a pull request for the current branch. TRIGGER when user asks to create a PR, open a pull request, push changes for review, or submit work for merging.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Create Pull Request
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check for existing PR**: `gh pr view --json url -q .url 2>/dev/null` — if a PR already exists, output its URL and stop
|
||||
2. **Understand changes**: `git status`, `git diff dev...HEAD`, `git log dev..HEAD --oneline`
|
||||
3. **Read PR template**: `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
4. **Draft PR title**: Use conventional commits format (see CLAUDE.md for types and scopes)
|
||||
5. **Fill out PR template** as the body — be thorough in the Changes section
|
||||
6. **Format first** (if relevant changes exist):
|
||||
- Backend: `cd autogpt_platform/backend && poetry run format`
|
||||
- Frontend: `cd autogpt_platform/frontend && pnpm format`
|
||||
- Fix any lint errors, then commit formatting changes before pushing
|
||||
7. **Push**: `git push -u origin HEAD`
|
||||
8. **Create PR**: `gh pr create --base dev`
|
||||
9. **Output** the PR URL
|
||||
|
||||
## Rules
|
||||
|
||||
- Always target `dev` branch
|
||||
- Do NOT run tests — CI will handle that
|
||||
- Use the PR template from `.github/PULL_REQUEST_TEMPLATE.md`
|
||||
@@ -1,51 +0,0 @@
|
||||
---
|
||||
name: pr-review
|
||||
description: Address all open PR review comments systematically. Fetches comments, addresses each one, reacts +1/-1, and replies when clarification is needed. Keeps iterating until all comments are addressed and CI is green. TRIGGER when user shares a PR URL, asks to address review comments, fix PR feedback, or respond to reviewer comments.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# PR Review Comment Workflow
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Find PR**: `gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT`
|
||||
2. **Fetch comments** (all three sources):
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` (top-level reviews)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` (inline review comments)
|
||||
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` (PR conversation comments)
|
||||
3. **Skip** comments already reacted to by PR author
|
||||
4. **For each unreacted comment**:
|
||||
- Read referenced code, make the fix (or reply if you disagree/need info)
|
||||
- **Inline review comments** (`pulls/{N}/comments`):
|
||||
- React: `gh api repos/.../pulls/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- Reply: `gh api repos/.../pulls/{N}/comments/{ID}/replies -f body="..."`
|
||||
- **PR conversation comments** (`issues/{N}/comments`):
|
||||
- React: `gh api repos/.../issues/comments/{ID}/reactions -f content="+1"` (or `-1`)
|
||||
- No threaded replies — post a new issue comment if needed
|
||||
- **Top-level reviews**: no reaction API — address in code, reply via issue comment if needed
|
||||
5. **Include autogpt-reviewer bot fixes** too
|
||||
6. **Format**: `cd autogpt_platform/backend && poetry run format`, `cd autogpt_platform/frontend && pnpm format`
|
||||
7. **Commit & push**
|
||||
8. **Re-fetch comments** immediately — address any new unreacted ones before waiting on CI
|
||||
9. **Stay productive while CI runs** — don't idle. In priority order:
|
||||
- Run any pending local tests (`poetry run pytest`, e2e, etc.) and fix failures
|
||||
- Address any remaining comments
|
||||
- Only poll `gh pr checks {N}` as the last resort when there's truly nothing left to do
|
||||
10. **If CI fails** — fix, go back to step 6
|
||||
11. **Re-fetch comments again** after CI is green — address anything that appeared while CI was running
|
||||
12. **Done** only when: all comments reacted AND CI is green.
|
||||
|
||||
## CRITICAL: Do Not Stop
|
||||
|
||||
**Loop is: address → format → commit → push → re-check comments → run local tests → wait CI → re-check comments → repeat.**
|
||||
|
||||
Never idle. If CI is running and you have nothing to address, run local tests. Waiting on CI is the last resort.
|
||||
|
||||
## Rules
|
||||
|
||||
- One todo per comment
|
||||
- For inline review comments: reply on existing threads. For PR conversation comments: post a new issue comment (API doesn't support threaded replies)
|
||||
- React to every comment: +1 addressed, -1 disagreed (with explanation)
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
name: worktree-setup
|
||||
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, generates Prisma client, and optionally starts the app (with port conflict resolution) or runs tests. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||
user-invocable: true
|
||||
metadata:
|
||||
author: autogpt-team
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Worktree Setup
|
||||
|
||||
## Preferred: Use Branchlet
|
||||
|
||||
The repo has a `.branchlet.json` config — it handles env file copying, dependency installation, and Prisma generation automatically.
|
||||
|
||||
```bash
|
||||
npm install -g branchlet # install once
|
||||
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||
branchlet list --json # list all worktrees
|
||||
```
|
||||
|
||||
## Manual Fallback
|
||||
|
||||
If branchlet isn't available:
|
||||
|
||||
1. `git worktree add ../<RepoName><N> <branch-name>`
|
||||
2. Copy `.env` files: `backend/.env`, `frontend/.env`, `autogpt_platform/.env`, `db/docker/.env`
|
||||
3. Install deps:
|
||||
- `cd autogpt_platform/backend && poetry install && poetry run prisma generate`
|
||||
- `cd autogpt_platform/frontend && pnpm install`
|
||||
|
||||
## Running the App
|
||||
|
||||
Free ports first — backend uses: 8001, 8002, 8003, 8005, 8006, 8007, 8008.
|
||||
|
||||
```bash
|
||||
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||
done
|
||||
cd <worktree>/autogpt_platform/backend && poetry run app
|
||||
```
|
||||
|
||||
## CoPilot Testing Gotcha
|
||||
|
||||
SDK mode spawns a Claude subprocess — **won't work inside Claude Code**. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v7
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: classic/
|
||||
file: classic/Dockerfile.autogpt
|
||||
|
||||
@@ -47,7 +47,7 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v7
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: classic/
|
||||
file: classic/Dockerfile.autogpt
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v7
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: classic/
|
||||
file: classic/Dockerfile.autogpt
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v7
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: classic/
|
||||
file: Dockerfile.autogpt
|
||||
|
||||
61
.github/workflows/platform-frontend-ci.yml
vendored
61
.github/workflows/platform-frontend-ci.yml
vendored
@@ -83,6 +83,65 @@ jobs:
|
||||
- name: Run lint
|
||||
run: pnpm lint
|
||||
|
||||
react-doctor:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run React Doctor
|
||||
id: react-doctor
|
||||
continue-on-error: true
|
||||
run: |
|
||||
OUTPUT=$(pnpm react-doctor:diff 2>&1) || true
|
||||
echo "$OUTPUT"
|
||||
SCORE=$(echo "$OUTPUT" | grep -oP '\d+(?= / 100)' | head -1)
|
||||
echo "score=${SCORE:-0}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Check React Doctor score
|
||||
env:
|
||||
RD_SCORE: ${{ steps.react-doctor.outputs.score }}
|
||||
MIN_SCORE: "90"
|
||||
run: |
|
||||
echo "React Doctor score: ${RD_SCORE}/100 (minimum: ${MIN_SCORE})"
|
||||
if [ "${RD_SCORE}" -lt "${MIN_SCORE}" ]; then
|
||||
echo "::error::React Doctor score ${RD_SCORE} is below the minimum threshold of ${MIN_SCORE}."
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " React Doctor score too low!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "To fix these issues, run Claude Code locally:"
|
||||
echo ""
|
||||
echo " cd autogpt_platform/frontend"
|
||||
echo " claude"
|
||||
echo ""
|
||||
echo "Then ask Claude to run react-doctor and fix the issues."
|
||||
echo "You can also run it manually:"
|
||||
echo ""
|
||||
echo " pnpm react-doctor # scan all files"
|
||||
echo " pnpm react-doctor:diff # scan only changed files"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
@@ -149,7 +208,7 @@ jobs:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -180,6 +180,4 @@ autogpt_platform/backend/settings.py
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.next
|
||||
@@ -1,10 +1,3 @@
|
||||
default_install_hook_types:
|
||||
- pre-commit
|
||||
- pre-push
|
||||
- post-checkout
|
||||
|
||||
default_stages: [pre-commit]
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
@@ -24,7 +17,6 @@ repos:
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
|
||||
- repo: local
|
||||
@@ -34,106 +26,49 @@ repos:
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Backend
|
||||
alias: poetry-install-platform-backend
|
||||
entry: poetry -C autogpt_platform/backend install
|
||||
# include autogpt_libs source (since it's a path dependency)
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/backend install
|
||||
'
|
||||
always_run: true
|
||||
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Libs
|
||||
alias: poetry-install-platform-libs
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/autogpt_libs install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C autogpt_platform/autogpt_libs install
|
||||
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: pnpm-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Frontend
|
||||
alias: pnpm-install-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
|
||||
pnpm --prefix autogpt_platform/frontend install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
entry: poetry -C classic/original_autogpt install
|
||||
# include forge source (since it's a path dependency)
|
||||
always_run: true
|
||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C classic/forge install
|
||||
files: ^classic/forge/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
'
|
||||
always_run: true
|
||||
entry: poetry -C classic/benchmark install
|
||||
files: ^classic/benchmark/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, Prisma client must be up-to-date.
|
||||
@@ -141,54 +76,12 @@ repos:
|
||||
- id: prisma-generate
|
||||
name: Prisma Generate - AutoGPT Platform - Backend
|
||||
alias: prisma-generate-platform-backend
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run prisma generate
|
||||
&& poetry run gen-prisma-stub
|
||||
'
|
||||
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
|
||||
# include everything that triggers poetry install + the prisma schema
|
||||
always_run: true
|
||||
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: export-api-schema
|
||||
name: Export API schema - AutoGPT Platform - Backend -> Frontend
|
||||
alias: export-api-schema-platform
|
||||
entry: >
|
||||
bash -c '
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
&& cd ../frontend
|
||||
&& pnpm prettier --write ./src/app/api/openapi.json
|
||||
'
|
||||
files: ^autogpt_platform/backend/
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: generate-api-client
|
||||
name: Generate API client - AutoGPT Platform - Frontend
|
||||
alias: generate-api-client-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
|
||||
else
|
||||
git diff --quiet HEAD -- "$SCHEMA" && exit 0
|
||||
fi;
|
||||
cd autogpt_platform/frontend && pnpm generate:api
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.7.2
|
||||
|
||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
.application.logs
|
||||
*.ign.*
|
||||
@@ -190,8 +190,5 @@ ZEROBOUNCE_API_KEY=
|
||||
POSTHOG_API_KEY=
|
||||
POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# Tally Form Integration (pre-populate business understanding on signup)
|
||||
TALLY_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||
# for the bash_exec MCP tool.
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
@@ -111,29 +111,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma and agent-browser.
|
||||
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||
# proper symlinks so npm/npx can find their modules.
|
||||
# Copy Node.js installation for Prisma
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||
fonts-liberation libfontconfig1 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& npm install -g agent-browser \
|
||||
&& agent-browser install \
|
||||
&& rm -rf /tmp/* /root/.npm
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
|
||||
@@ -88,23 +88,20 @@ async def require_auth(
|
||||
)
|
||||
|
||||
|
||||
def require_permission(*permissions: APIKeyPermission):
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking required permissions.
|
||||
All listed permissions must be present.
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permissions(
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
missing = [p for p in permissions if p not in auth.scopes]
|
||||
if missing:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission(s): "
|
||||
f"{', '.join(p.value for p in missing)}",
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permissions
|
||||
return check_permission
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, Any, Optional, Sequence
|
||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Security
|
||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||
@@ -9,17 +9,15 @@ from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.blocks
|
||||
from backend.api.external.middleware import require_auth, require_permission
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
@@ -97,43 +95,6 @@ async def execute_graph_block(
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs",
|
||||
tags=["graphs"],
|
||||
status_code=201,
|
||||
dependencies=[
|
||||
Security(
|
||||
require_permission(
|
||||
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
async def create_graph(
|
||||
graph: graph_db.Graph,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
|
||||
),
|
||||
) -> graph_db.GraphModel:
|
||||
"""
|
||||
Create a new agent graph.
|
||||
|
||||
The graph will be validated and assigned a new ID.
|
||||
It is automatically added to the user's library.
|
||||
"""
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
graph_model = graph_db.make_graph_model(graph, auth.user_id)
|
||||
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
|
||||
graph_model.validate_graph(for_run=False)
|
||||
|
||||
await graph_db.create_graph(graph_model, user_id=auth.user_id)
|
||||
await library_db.create_library_agent(graph_model, auth.user_id)
|
||||
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
|
||||
|
||||
return activated_graph
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
@@ -231,13 +192,13 @@ async def get_graph_execution_results(
|
||||
@v1_router.get(
|
||||
path="/store/agents",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
@@ -279,7 +240,7 @@ async def get_store_agents(
|
||||
@v1_router.get(
|
||||
path="/store/agents/{username}/{agent_name}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(
|
||||
@@ -307,13 +268,13 @@ async def get_store_agent(
|
||||
@v1_router.get(
|
||||
path="/store/creators",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.CreatorsResponse:
|
||||
@@ -349,7 +310,7 @@ async def get_store_creators(
|
||||
@v1_router.get(
|
||||
path="/store/creators/{username}",
|
||||
tags=["store"],
|
||||
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_store_creator(
|
||||
|
||||
@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
||||
)
|
||||
async def get_admin_listings_with_versions(
|
||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||
search: typing.Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||
):
|
||||
"""
|
||||
Get store listings with their version history for admins.
|
||||
|
||||
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
Paginated listings with their versions
|
||||
StoreListingsWithVersionsResponse with listings and their versions
|
||||
"""
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
try:
|
||||
listings = await store_db.get_admin_listings_with_versions(
|
||||
status=status,
|
||||
search_query=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception as e:
|
||||
logger.exception("Error getting admin listings with versions: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving listings with versions"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def review_submission(
|
||||
store_listing_version_id: str,
|
||||
request: store_model.ReviewSubmissionRequest,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmissionAdminView:
|
||||
):
|
||||
"""
|
||||
Review a store listing submission.
|
||||
|
||||
@@ -73,24 +84,31 @@ async def review_submission(
|
||||
user_id: Authenticated admin user performing the review
|
||||
|
||||
Returns:
|
||||
StoreSubmissionAdminView with updated review information
|
||||
StoreSubmission with updated review information
|
||||
"""
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
try:
|
||||
already_approved = await store_db.check_submission_already_approved(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
)
|
||||
submission = await store_db.review_store_submission(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
is_approved=request.is_approved,
|
||||
external_comments=request.comments,
|
||||
internal_comments=request.internal_comments or "",
|
||||
reviewer_id=user_id,
|
||||
)
|
||||
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches whenever approval state changes, since store visibility can change
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
state_changed = already_approved != request.is_approved
|
||||
# Clear caches when the request is approved as it updates what is shown on the store
|
||||
if state_changed:
|
||||
store_cache.clear_all_caches()
|
||||
return submission
|
||||
except Exception as e:
|
||||
logger.exception("Error reviewing submission: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while reviewing the submission"},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
from typing import Sequence
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -21,6 +19,7 @@ from backend.blocks._base import (
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -43,16 +42,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
||||
|
||||
# Boost blocks over marketplace agents in search results
|
||||
BLOCK_SCORE_BOOST = 50.0
|
||||
|
||||
# Block IDs to exclude from search results
|
||||
EXCLUDED_BLOCK_IDS = frozenset(
|
||||
{
|
||||
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
|
||||
}
|
||||
)
|
||||
|
||||
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
||||
|
||||
|
||||
@@ -75,8 +64,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
# Skip disabled and excluded blocks
|
||||
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
@@ -127,9 +116,6 @@ def get_blocks(
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
@@ -269,25 +255,14 @@ async def _build_cached_search_results(
|
||||
"my_agents": 0,
|
||||
}
|
||||
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
elif include_blocks or include_integrations:
|
||||
# No query - list all blocks using in-memory approach
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
normalized_query=normalized_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
|
||||
if include_library_agents:
|
||||
library_response = await library_db.list_library_agents(
|
||||
@@ -332,14 +307,10 @@ async def _build_cached_search_results(
|
||||
|
||||
def _collect_block_results(
|
||||
*,
|
||||
normalized_query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Collect all blocks for listing (no search query).
|
||||
|
||||
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
@@ -352,10 +323,6 @@ def _collect_block_results(
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
block_info = block.get_info()
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
is_integration = len(credentials) > 0
|
||||
@@ -365,6 +332,10 @@ def _collect_block_results(
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
score = _score_block(block, block_info, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
@@ -375,122 +346,8 @@ def _collect_block_results(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=BLOCK_SCORE_BOOST,
|
||||
sort_key=block_info.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
score=score,
|
||||
sort_key=_get_item_name(block_info),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -615,8 +472,6 @@ async def _get_static_counts():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
all_blocks += 1
|
||||
|
||||
@@ -643,25 +498,47 @@ async def _get_static_counts():
|
||||
}
|
||||
|
||||
|
||||
def _contains_type(annotation: Any, target: type) -> bool:
|
||||
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||
if annotation is target:
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||
|
||||
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if _contains_type(field.annotation, LlmModel):
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _score_block(
|
||||
block: AnyBlockSchema,
|
||||
block_info: BlockInfo,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = block_info.name.lower()
|
||||
description = block_info.description.lower()
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
category_text = " ".join(
|
||||
category.get("category", "").lower() for category in block_info.categories
|
||||
)
|
||||
score += _score_additional_field(category_text, normalized_query, 12, 6)
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info
|
||||
for provider in info.provider
|
||||
]
|
||||
provider_text = " ".join(provider_names)
|
||||
score += _score_additional_field(provider_text, normalized_query, 15, 6)
|
||||
|
||||
if _matches_llm_model(block.input_schema, normalized_query):
|
||||
score += 20
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_library_agent(
|
||||
agent: library_model.LibraryAgent,
|
||||
normalized_query: str,
|
||||
@@ -768,20 +645,31 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600, shared_cache=True)
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
"""Return the most-executed blocks from the last 14 days.
|
||||
suggested_blocks = []
|
||||
# Sum the number of executions for each block type
|
||||
# Prisma cannot group by nested relations, so we do a raw query
|
||||
# Calculate the cutoff timestamp
|
||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
|
||||
and returns the top `count` blocks sorted by execution count, excluding
|
||||
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
|
||||
"""
|
||||
results = await mv_suggested_blocks.prisma().find_many()
|
||||
results = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
agent_node."agentBlockId" AS block_id,
|
||||
COUNT(execution.id) AS execution_count
|
||||
FROM {schema_prefix}"AgentNodeExecution" execution
|
||||
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
||||
WHERE execution."endedTime" >= $1::timestamp
|
||||
GROUP BY agent_node."agentBlockId"
|
||||
ORDER BY execution_count DESC;
|
||||
""",
|
||||
timestamp_threshold,
|
||||
)
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input, Output, Agent, and excluded blocks
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
execution_counts = {row.block_id: row.execution_count for row in results}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
@@ -791,9 +679,11 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
BlockType.AGENT,
|
||||
):
|
||||
continue
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
execution_count = execution_counts.get(block.id, 0)
|
||||
# Find the execution count for this block
|
||||
execution_count = next(
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[SearchEntry]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockInfo]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Annotated, Sequence, cast, get_args
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -51,6 +49,11 @@ async def get_suggestions(
|
||||
Get all suggestions for the Blocks Menu.
|
||||
"""
|
||||
return builder_model.SuggestionsResponse(
|
||||
otto_suggestions=[
|
||||
"What blocks do I need to get started?",
|
||||
"Help me create a list",
|
||||
"Help me feed my data to Google Maps",
|
||||
],
|
||||
recent_searches=await builder_db.get_recent_searches(user_id),
|
||||
providers=[
|
||||
ProviderName.TWITTER,
|
||||
@@ -148,7 +151,7 @@ async def get_providers(
|
||||
async def search(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
search_query: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
|
||||
search_id: Annotated[str | None, fastapi.Query()] = None,
|
||||
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
@@ -157,20 +160,9 @@ async def search(
|
||||
"""
|
||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||
"""
|
||||
# Parse and validate filter parameter
|
||||
filters: list[builder_model.FilterType]
|
||||
if filter:
|
||||
filter_values = [f.strip() for f in filter.split(",")]
|
||||
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
|
||||
if invalid_filters:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
|
||||
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
|
||||
)
|
||||
filters = cast(list[builder_model.FilterType], filter_values)
|
||||
else:
|
||||
filters = [
|
||||
# If no filters are provided, then we will return all types
|
||||
if not filter:
|
||||
filter = [
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
@@ -182,7 +174,7 @@ async def search(
|
||||
cached_results = await builder_db.get_sorted_search_results(
|
||||
user_id=user_id,
|
||||
search_query=search_query,
|
||||
filters=filters,
|
||||
filters=filter,
|
||||
by_creator=by_creator,
|
||||
)
|
||||
|
||||
@@ -204,7 +196,7 @@ async def search(
|
||||
user_id,
|
||||
builder_model.SearchEntry(
|
||||
search_query=search_query,
|
||||
filter=filters,
|
||||
filter=filter,
|
||||
by_creator=by_creator,
|
||||
search_id=search_id,
|
||||
),
|
||||
|
||||
@@ -2,21 +2,23 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -25,10 +27,8 @@ from backend.copilot.model import (
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
@@ -44,23 +44,20 @@ from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
MCPToolOutputResponse,
|
||||
MCPToolsDiscoveredResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,9 +86,6 @@ class StreamChatRequest(BaseModel):
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -105,8 +99,10 @@ class CreateSessionResponse(BaseModel):
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
turn_id: str
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
@@ -136,25 +132,20 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class CancelSessionResponse(BaseModel):
|
||||
"""Response model for the cancel session endpoint."""
|
||||
class CancelTaskResponse(BaseModel):
|
||||
"""Response model for the cancel task endpoint."""
|
||||
|
||||
cancelled: bool
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class UpdateSessionTitleRequest(BaseModel):
|
||||
"""Request model for updating a session's title."""
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
title: str
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_must_not_be_blank(cls, v: str) -> str:
|
||||
stripped = v.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Title must not be blank")
|
||||
return stripped
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
@@ -265,58 +256,9 @@ async def delete_session(
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
# Best-effort cleanup of the E2B sandbox (if any).
|
||||
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||
e2b_cfg = ChatConfig()
|
||||
if e2b_cfg.e2b_active:
|
||||
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||
try:
|
||||
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def update_session_title_route(
|
||||
session_id: str,
|
||||
request: UpdateSessionTitleRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Update the title of a chat session.
|
||||
|
||||
Allows the user to rename their chat session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
request: Request body containing the new title.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
dict: Status of the update.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
success = await update_session_title(session_id, user_id, request.title)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -328,7 +270,7 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
@@ -346,21 +288,28 @@ async def get_session(
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -380,7 +329,7 @@ async def get_session(
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
) -> CancelTaskResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
|
||||
@@ -389,33 +338,39 @@ async def cancel_session_task(
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if not active_session:
|
||||
return CancelSessionResponse(cancelled=True, reason="no_active_session")
|
||||
active_task, _ = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if not active_task:
|
||||
return CancelTaskResponse(cancelled=False, reason="no_active_task")
|
||||
|
||||
await enqueue_cancel_task(session_id)
|
||||
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
||||
task_id = active_task.task_id
|
||||
await enqueue_cancel_task(task_id)
|
||||
logger.info(
|
||||
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
|
||||
f"session ...{session_id[-8:]}"
|
||||
)
|
||||
|
||||
# Poll until the executor confirms the task is no longer running.
|
||||
# Keep max_wait below typical reverse-proxy read timeouts.
|
||||
poll_interval = 0.5
|
||||
max_wait = 5.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
await asyncio.sleep(poll_interval)
|
||||
waited += poll_interval
|
||||
session_state = await stream_registry.get_session(session_id)
|
||||
if session_state is None or session_state.status != "running":
|
||||
task = await stream_registry.get_task(task_id)
|
||||
if task is None or task.status != "running":
|
||||
logger.info(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
||||
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
|
||||
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
|
||||
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
|
||||
)
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
return CancelTaskResponse(cancelled=True, task_id=task_id)
|
||||
|
||||
logger.warning(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
|
||||
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
|
||||
return CancelTaskResponse(
|
||||
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
|
||||
)
|
||||
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -435,15 +390,16 @@ async def stream_chat_post(
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
@@ -470,38 +426,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
@@ -522,38 +446,37 @@ async def stream_chat_post(
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
@@ -568,7 +491,7 @@ async def stream_chat_post(
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -576,12 +499,11 @@ async def stream_chat_post(
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=subscribe_from_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -596,7 +518,7 @@ async def stream_chat_post(
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
@@ -664,19 +586,19 @@ async def stream_chat_post(
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -723,21 +645,17 @@ async def resume_session_stream(
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_session:
|
||||
if not active_task:
|
||||
return Response(status_code=204)
|
||||
|
||||
# Always replay from the beginning ("0-0") on resume.
|
||||
# We can't use last_message_id because it's the latest ID in the backend
|
||||
# stream, not the latest the frontend received — the gap causes lost
|
||||
# messages. The frontend deduplicates replayed content.
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0",
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -749,7 +667,7 @@ async def resume_session_stream(
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
@@ -773,12 +691,12 @@ async def resume_session_stream(
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
@@ -806,6 +724,7 @@ async def resume_session_stream(
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
@@ -828,6 +747,229 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -908,8 +1050,9 @@ ToolResponseUnion = (
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| MCPToolsDiscoveredResponse
|
||||
| MCPToolOutputResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
"""Tests for chat API routes: session title update and file attachment validation."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _mock_update_session_title(
|
||||
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||
):
|
||||
"""Mock update_session_title."""
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.update_session_title",
|
||||
new_callable=AsyncMock,
|
||||
return_value=success,
|
||||
)
|
||||
|
||||
|
||||
# ─── Update title: success ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_title_success(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "My project"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||
|
||||
|
||||
def test_update_title_trims_whitespace(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_update = _mock_update_session_title(mocker, success=True)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " trimmed "},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||
|
||||
|
||||
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||
|
||||
|
||||
def test_update_title_blank_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_update_title_empty_rejected(
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": ""},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||
|
||||
|
||||
def test_update_title_not_found(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
_mock_update_session_title(mocker, success=False)
|
||||
|
||||
response = client.patch(
|
||||
"/sessions/sess-1/title",
|
||||
json={"title": "New name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
@@ -22,7 +22,6 @@ from backend.data.human_review import (
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -322,13 +321,10 @@ async def process_review_action(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import prisma.enums
|
||||
import prisma.models
|
||||
import pytest
|
||||
|
||||
import backend.api.features.store.exceptions
|
||||
from backend.data.db import connect
|
||||
from backend.data.includes import library_agent_include
|
||||
|
||||
@@ -143,7 +144,6 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
@@ -178,6 +178,7 @@ async def test_add_agent_to_library(mocker):
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
)
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(db.NotFoundError):
|
||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
class FolderValidationError(Exception):
|
||||
"""Raised when folder operations fail validation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FolderAlreadyExistsError(FolderValidationError):
|
||||
"""Raised when a folder with the same name already exists in the location."""
|
||||
|
||||
pass
|
||||
@@ -26,95 +26,6 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
# === Folder Models ===
|
||||
|
||||
|
||||
class LibraryFolder(pydantic.BaseModel):
|
||||
"""Represents a folder for organizing library agents."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
parent_id: str | None = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
agent_count: int = 0 # Direct agents in folder
|
||||
subfolder_count: int = 0 # Direct child folders
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
folder: prisma.models.LibraryFolder,
|
||||
agent_count: int = 0,
|
||||
subfolder_count: int = 0,
|
||||
) -> "LibraryFolder":
|
||||
"""Factory method that constructs a LibraryFolder from a Prisma model."""
|
||||
return LibraryFolder(
|
||||
id=folder.id,
|
||||
user_id=folder.userId,
|
||||
name=folder.name,
|
||||
icon=folder.icon,
|
||||
color=folder.color,
|
||||
parent_id=folder.parentId,
|
||||
created_at=folder.createdAt,
|
||||
updated_at=folder.updatedAt,
|
||||
agent_count=agent_count,
|
||||
subfolder_count=subfolder_count,
|
||||
)
|
||||
|
||||
|
||||
class LibraryFolderTree(LibraryFolder):
|
||||
"""Folder with nested children for tree view."""
|
||||
|
||||
children: list["LibraryFolderTree"] = []
|
||||
|
||||
|
||||
class FolderCreateRequest(pydantic.BaseModel):
|
||||
"""Request model for creating a folder."""
|
||||
|
||||
name: str = pydantic.Field(..., min_length=1, max_length=100)
|
||||
icon: str | None = None
|
||||
color: str | None = pydantic.Field(
|
||||
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
|
||||
)
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class FolderUpdateRequest(pydantic.BaseModel):
|
||||
"""Request model for updating a folder."""
|
||||
|
||||
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
|
||||
|
||||
class FolderMoveRequest(pydantic.BaseModel):
|
||||
"""Request model for moving a folder to a new parent."""
|
||||
|
||||
target_parent_id: str | None = None # None = move to root
|
||||
|
||||
|
||||
class BulkMoveAgentsRequest(pydantic.BaseModel):
|
||||
"""Request model for moving multiple agents to a folder."""
|
||||
|
||||
agent_ids: list[str]
|
||||
folder_id: str | None = None # None = move to root
|
||||
|
||||
|
||||
class FolderListResponse(pydantic.BaseModel):
|
||||
"""Response schema for a list of folders."""
|
||||
|
||||
folders: list[LibraryFolder]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class FolderTreeResponse(pydantic.BaseModel):
|
||||
"""Response schema for folder tree structure."""
|
||||
|
||||
tree: list[LibraryFolderTree]
|
||||
|
||||
|
||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||
"""Creator information for a marketplace listing."""
|
||||
|
||||
@@ -209,9 +120,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
@@ -351,8 +259,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
@@ -564,7 +470,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
settings: Optional[GraphSettings] = pydantic.Field(
|
||||
default=None, description="User-specific settings for this library agent"
|
||||
)
|
||||
folder_id: Optional[str] = pydantic.Field(
|
||||
default=None,
|
||||
description="Folder ID to move agent to (None to move to root)",
|
||||
)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import fastapi
|
||||
|
||||
from .agents import router as agents_router
|
||||
from .folders import router as folders_router
|
||||
from .presets import router as presets_router
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
router.include_router(presets_router)
|
||||
router.include_router(folders_router)
|
||||
router.include_router(agents_router)
|
||||
|
||||
@@ -41,14 +41,6 @@ async def list_library_agents(
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
folder_id: Optional[str] = Query(
|
||||
None,
|
||||
description="Filter by folder ID",
|
||||
),
|
||||
include_root_only: bool = Query(
|
||||
False,
|
||||
description="Only return agents without a folder (root-level agents)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all agents in the user's library (both created and saved).
|
||||
@@ -59,8 +51,6 @@ async def list_library_agents(
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
folder_id=folder_id,
|
||||
include_root_only=include_root_only,
|
||||
)
|
||||
|
||||
|
||||
@@ -178,7 +168,6 @@ async def update_library_agent(
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
folder_id=payload.folder_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/folders",
|
||||
tags=["library", "folders", "private"],
|
||||
dependencies=[Security(autogpt_auth_lib.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="List Library Folders",
|
||||
response_model=library_model.FolderListResponse,
|
||||
responses={
|
||||
200: {"description": "List of folders"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def list_folders(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
parent_id: Optional[str] = Query(
|
||||
None,
|
||||
description="Filter by parent folder ID. If not provided, returns root-level folders.",
|
||||
),
|
||||
include_relations: bool = Query(
|
||||
True,
|
||||
description="Include agent and subfolder relations (for counts)",
|
||||
),
|
||||
) -> library_model.FolderListResponse:
|
||||
"""
|
||||
List folders for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
parent_id: Optional parent folder ID to filter by.
|
||||
include_relations: Whether to include agent and subfolder relations for counts.
|
||||
|
||||
Returns:
|
||||
A FolderListResponse containing folders.
|
||||
"""
|
||||
folders = await library_db.list_folders(
|
||||
user_id=user_id,
|
||||
parent_id=parent_id,
|
||||
include_relations=include_relations,
|
||||
)
|
||||
return library_model.FolderListResponse(
|
||||
folders=folders,
|
||||
pagination=library_model.Pagination(
|
||||
total_items=len(folders),
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=len(folders),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tree",
|
||||
summary="Get Folder Tree",
|
||||
response_model=library_model.FolderTreeResponse,
|
||||
responses={
|
||||
200: {"description": "Folder tree structure"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def get_folder_tree(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.FolderTreeResponse:
|
||||
"""
|
||||
Get the full folder tree for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
A FolderTreeResponse containing the nested folder structure.
|
||||
"""
|
||||
tree = await library_db.get_folder_tree(user_id=user_id)
|
||||
return library_model.FolderTreeResponse(tree=tree)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{folder_id}",
|
||||
summary="Get Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder details"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def get_folder(
|
||||
folder_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Get a specific folder.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to retrieve.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The requested LibraryFolder.
|
||||
"""
|
||||
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Create Folder",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
201: {"description": "Folder created successfully"},
|
||||
400: {"description": "Validation error"},
|
||||
404: {"description": "Parent folder not found"},
|
||||
409: {"description": "Folder name conflict"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def create_folder(
|
||||
payload: library_model.FolderCreateRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Create a new folder.
|
||||
|
||||
Args:
|
||||
payload: The folder creation request.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The created LibraryFolder.
|
||||
"""
|
||||
return await library_db.create_folder(
|
||||
user_id=user_id,
|
||||
name=payload.name,
|
||||
parent_id=payload.parent_id,
|
||||
icon=payload.icon,
|
||||
color=payload.color,
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{folder_id}",
|
||||
summary="Update Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder updated successfully"},
|
||||
400: {"description": "Validation error"},
|
||||
404: {"description": "Folder not found"},
|
||||
409: {"description": "Folder name conflict"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def update_folder(
|
||||
folder_id: str,
|
||||
payload: library_model.FolderUpdateRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Update a folder's properties.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to update.
|
||||
payload: The folder update request.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The updated LibraryFolder.
|
||||
"""
|
||||
return await library_db.update_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
name=payload.name,
|
||||
icon=payload.icon,
|
||||
color=payload.color,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{folder_id}/move",
|
||||
summary="Move Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder moved successfully"},
|
||||
400: {"description": "Validation error (circular reference)"},
|
||||
404: {"description": "Folder or target parent not found"},
|
||||
409: {"description": "Folder name conflict in target location"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def move_folder(
|
||||
folder_id: str,
|
||||
payload: library_model.FolderMoveRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Move a folder to a new parent.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to move.
|
||||
payload: The move request with target parent.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The moved LibraryFolder.
|
||||
"""
|
||||
return await library_db.move_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
target_parent_id=payload.target_parent_id,
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{folder_id}",
|
||||
summary="Delete Folder",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "Folder deleted successfully"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def delete_folder(
|
||||
folder_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete a folder and all its contents.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to delete.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 No Content if successful.
|
||||
"""
|
||||
await library_db.delete_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
soft_delete=True,
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
# === Bulk Agent Operations ===
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/bulk-move",
|
||||
summary="Bulk Move Agents",
|
||||
response_model=list[library_model.LibraryAgent],
|
||||
responses={
|
||||
200: {"description": "Agents moved successfully"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def bulk_move_agents(
|
||||
payload: library_model.BulkMoveAgentsRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Move multiple agents to a folder.
|
||||
|
||||
Args:
|
||||
payload: The bulk move request with agent IDs and target folder.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgents.
|
||||
"""
|
||||
return await library_db.bulk_move_agents_to_folder(
|
||||
agent_ids=payload.agent_ids,
|
||||
folder_id=payload.folder_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -115,8 +115,6 @@ async def test_get_library_agents_success(
|
||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
||||
page=1,
|
||||
page_size=15,
|
||||
folder_id=None,
|
||||
include_root_only=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,24 +7,20 @@ frontend can list available tools on an MCP server before placing a block.
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import Security
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
server_host,
|
||||
)
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests, validate_url
|
||||
from backend.util.request import HTTPClientError, Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,20 +74,32 @@ async def discover_tools(
|
||||
If the user has a stored MCP credential for this server URL, it will be
|
||||
used automatically — no need to pass an explicit auth token.
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
auth_token = request.auth_token
|
||||
|
||||
# Auto-use stored MCP credential when no explicit token is provided.
|
||||
if not auth_token:
|
||||
best_cred = await auto_lookup_mcp_credential(
|
||||
user_id, normalize_mcp_url(request.server_url)
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
||||
):
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
# Refresh the token if expired before using it
|
||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
|
||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||
@@ -126,7 +134,7 @@ async def discover_tools(
|
||||
],
|
||||
server_name=(
|
||||
init_result.get("serverInfo", {}).get("name")
|
||||
or server_host(request.server_url)
|
||||
or urlparse(request.server_url).hostname
|
||||
or "MCP"
|
||||
),
|
||||
protocol_version=init_result.get("protocolVersion"),
|
||||
@@ -165,16 +173,7 @@ async def mcp_oauth_login(
|
||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||
4. Returns the authorization URL for the frontend to open in a popup
|
||||
"""
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize the URL so that credentials stored here are matched consistently
|
||||
# by auto_lookup_mcp_credential (which also uses normalized URLs).
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
client = MCPClient(server_url)
|
||||
client = MCPClient(request.server_url)
|
||||
|
||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||
protected_resource = await client.discover_auth()
|
||||
@@ -183,16 +182,7 @@ async def mcp_oauth_login(
|
||||
|
||||
if protected_resource and protected_resource.get("authorization_servers"):
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", server_url)
|
||||
|
||||
# Validate the auth server URL from metadata to prevent SSRF.
|
||||
try:
|
||||
await validate_url(auth_server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid authorization server URL in metadata: {e}",
|
||||
)
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
@@ -202,7 +192,7 @@ async def mcp_oauth_login(
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
metadata = await client.discover_auth_server_metadata(server_url)
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
|
||||
if (
|
||||
not metadata
|
||||
@@ -232,18 +222,12 @@ async def mcp_oauth_login(
|
||||
client_id = ""
|
||||
client_secret = ""
|
||||
if registration_endpoint:
|
||||
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||
try:
|
||||
await validate_url(registration_endpoint, trusted_origins=[])
|
||||
except ValueError:
|
||||
pass # Skip registration, fall back to default client_id
|
||||
else:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, request.server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
|
||||
if not client_id:
|
||||
client_id = "autogpt-platform"
|
||||
@@ -261,7 +245,7 @@ async def mcp_oauth_login(
|
||||
"token_url": token_url,
|
||||
"revoke_url": revoke_url,
|
||||
"resource_url": resource_url,
|
||||
"server_url": server_url,
|
||||
"server_url": request.server_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
@@ -358,7 +342,7 @@ async def mcp_oauth_callback(
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = server_host(meta["server_url"])
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||
@@ -373,9 +357,7 @@ async def mcp_oauth_callback(
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(
|
||||
"Removed old MCP credential %s for %s",
|
||||
old.id,
|
||||
server_host(meta["server_url"]),
|
||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
@@ -393,93 +375,6 @@ async def mcp_oauth_callback(
|
||||
)
|
||||
|
||||
|
||||
# ======================== Bearer Token ======================== #
|
||||
|
||||
|
||||
class MCPStoreTokenRequest(BaseModel):
|
||||
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
|
||||
|
||||
server_url: str = Field(
|
||||
description="MCP server URL the token authenticates against"
|
||||
)
|
||||
token: SecretStr = Field(
|
||||
min_length=1, description="Bearer token / API key for the MCP server"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/token",
|
||||
summary="Store a bearer token for an MCP server",
|
||||
)
|
||||
async def mcp_store_token(
|
||||
request: MCPStoreTokenRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
"""
|
||||
Store a manually provided bearer token as an MCP credential.
|
||||
|
||||
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
|
||||
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
|
||||
``run_mcp_tool`` calls will automatically pick up the token via
|
||||
``_auto_lookup_credential``.
|
||||
"""
|
||||
token = request.token.get_secret_value().strip()
|
||||
if not token:
|
||||
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
|
||||
|
||||
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||
try:
|
||||
await validate_url(request.server_url, trusted_origins=[])
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||
|
||||
# Normalize URL so trailing-slash variants match existing credentials.
|
||||
server_url = normalize_mcp_url(request.server_url)
|
||||
hostname = server_host(server_url)
|
||||
|
||||
# Collect IDs of old credentials to clean up after successful create.
|
||||
old_cred_ids: list[str] = []
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
old_cred_ids = [
|
||||
old.id
|
||||
for old in old_creds
|
||||
if isinstance(old, OAuth2Credentials)
|
||||
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
|
||||
== server_url
|
||||
]
|
||||
except Exception:
|
||||
logger.debug("Could not query old MCP token credentials", exc_info=True)
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
provider=ProviderName.MCP.value,
|
||||
title=f"MCP: {hostname}",
|
||||
access_token=SecretStr(token),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": server_url},
|
||||
)
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
# Only delete old credentials after the new one is safely stored.
|
||||
for old_id in old_cred_ids:
|
||||
try:
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old_id)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP token credential", exc_info=True)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=hostname,
|
||||
)
|
||||
|
||||
|
||||
# ======================== Helpers ======================== #
|
||||
|
||||
|
||||
@@ -505,7 +400,5 @@ async def _register_mcp_client(
|
||||
return data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Dynamic client registration failed for %s: %s", server_host(server_url), e
|
||||
)
|
||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||
return None
|
||||
|
||||
@@ -11,11 +11,9 @@ import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -30,16 +28,6 @@ async def client():
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_ssrf_validation():
|
||||
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
@@ -68,12 +56,9 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={
|
||||
@@ -122,6 +107,10 @@ class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
stored_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: example.com",
|
||||
@@ -135,12 +124,10 @@ class TestDiscoverTools:
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stored_cred,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
@@ -162,12 +149,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
@@ -185,12 +169,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
@@ -206,12 +187,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
@@ -229,12 +207,9 @@ class TestDiscoverTools:
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
@@ -356,6 +331,10 @@ class TestOAuthLogin:
|
||||
class TestOAuthCallback:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
mock_creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title=None,
|
||||
@@ -455,118 +434,3 @@ class TestOAuthCallback:
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "token exchange failed" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestStoreToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_success(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
mock_cm.create = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "my-api-key-123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["provider"] == "mcp"
|
||||
assert data["type"] == "oauth2"
|
||||
assert data["host"] == "mcp.example.com"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_blank_rejected(self, client):
|
||||
"""Blank token string (after stripping) should return 422."""
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": " ",
|
||||
},
|
||||
)
|
||||
# Pydantic min_length=1 catches the whitespace-only token
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_replaces_old_credential(self, client):
|
||||
old_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: mcp.example.com",
|
||||
access_token=SecretStr("old-token"),
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
|
||||
mock_cm.create = AsyncMock()
|
||||
mock_cm.store.delete_creds_by_id = AsyncMock()
|
||||
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"token": "new-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cm.store.delete_creds_by_id.assert_called_once_with(
|
||||
"test-user-id", old_cred.id
|
||||
)
|
||||
|
||||
|
||||
class TestSSRFValidation:
|
||||
"""Verify that validate_url is enforced on all endpoints."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "http://localhost/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked private IP"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "http://10.0.0.1/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked private ip" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_token_ssrf_blocked(self, client):
|
||||
with patch(
|
||||
"backend.api.features.mcp.routes.validate_url",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError("blocked loopback"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/token",
|
||||
json={
|
||||
"server_url": "http://127.0.0.1/mcp",
|
||||
"token": "some-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "blocked loopback" in response.json()["detail"].lower()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from backend.util.cache import cached
|
||||
|
||||
from . import db as store_db
|
||||
@@ -21,7 +23,7 @@ def clear_all_caches():
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
@@ -55,7 +57,7 @@ async def _get_cached_agent_details(
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
@@ -73,4 +75,4 @@ async def _get_cached_store_creators(
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await store_db.get_store_creator(username=username.lower())
|
||||
return await store_db.get_store_creator_details(username=username.lower())
|
||||
|
||||
@@ -9,26 +9,15 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _contains_type(annotation: Any, target: type) -> bool:
|
||||
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||
if annotation is target:
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
@@ -199,51 +188,45 @@ class BlockHandler(ContentHandler):
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Skip disabled blocks - they shouldn't be indexed
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if block_instance.name:
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if block_instance.description:
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if block_instance.categories:
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info.values()
|
||||
for provider in info.provider
|
||||
]
|
||||
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
[cat.value for cat in categories] if categories else []
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -252,11 +235,8 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": block_instance.name,
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
|
||||
@@ -82,10 +82,9 @@ async def test_block_handler_get_missing_items(mocker):
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
@@ -310,19 +309,19 @@ async def test_content_handlers_registry():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
@@ -353,8 +352,6 @@ async def test_block_handler_skips_failed_blocks():
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_agents = [
|
||||
prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
listing_version_id="version123",
|
||||
storeListingVersionId="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video=None,
|
||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=False,
|
||||
use_for_onboarding=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data - StoreAgent view already contains the active version data
|
||||
# Mock data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
listing_version_id="version123",
|
||||
storeListingVersionId="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
@@ -85,38 +85,102 @@ async def test_get_store_agent_details(mocker):
|
||||
runs=10,
|
||||
rating=4.5,
|
||||
versions=["1.0"],
|
||||
graph_id="test-graph-id",
|
||||
graph_versions=["1"],
|
||||
agentGraphVersions=["1"],
|
||||
agentGraphId="test-graph-id",
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
use_for_onboarding=False,
|
||||
is_available=False,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock StoreAgent prisma call
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id-active",
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results - constructed from the StoreAgent view
|
||||
# Verify results - should use active version data
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.active_version_id == "version123"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.has_approved_version is True
|
||||
assert result.store_listing_version_id == "version123"
|
||||
assert result.graph_id == "test-graph-id"
|
||||
assert result.runs == 10
|
||||
assert result.rating == 4.5
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
|
||||
# Verify single StoreAgent lookup
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator(mocker):
|
||||
async def test_get_store_creator_details(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
name="Test Creator",
|
||||
@@ -138,7 +202,7 @@ async def test_get_store_creator(mocker):
|
||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||
|
||||
# Call function
|
||||
result = await db.get_store_creator("creator")
|
||||
result = await db.get_store_creator_details("creator")
|
||||
|
||||
# Verify results
|
||||
assert result.username == "creator"
|
||||
@@ -154,110 +218,61 @@ async def test_get_store_creator(mocker):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
now = datetime.now()
|
||||
|
||||
# Mock agent graph (with no pending submissions) and user with profile
|
||||
mock_profile = prisma.models.Profile(
|
||||
id="profile-id",
|
||||
userId="user-id",
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test",
|
||||
isFeatured=False,
|
||||
links=[],
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
)
|
||||
mock_user = prisma.models.User(
|
||||
id="user-id",
|
||||
email="test@example.com",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
Profile=[mock_profile],
|
||||
emailVerified=True,
|
||||
metadata="{}", # type: ignore[reportArgumentType]
|
||||
integrations="",
|
||||
maxEmailsPerDay=1,
|
||||
notifyOnAgentRun=True,
|
||||
notifyOnZeroBalance=True,
|
||||
notifyOnLowBalance=True,
|
||||
notifyOnBlockExecutionFailed=True,
|
||||
notifyOnContinuousAgentError=True,
|
||||
notifyOnDailySummary=True,
|
||||
notifyOnWeeklySummary=True,
|
||||
notifyOnMonthlySummary=True,
|
||||
notifyOnAgentApproved=True,
|
||||
notifyOnAgentRejected=True,
|
||||
timezone="Europe/Delft",
|
||||
)
|
||||
# Mock data
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
id="agent-id",
|
||||
version=1,
|
||||
userId="user-id",
|
||||
createdAt=now,
|
||||
createdAt=datetime.now(),
|
||||
isActive=True,
|
||||
StoreListingVersions=[],
|
||||
User=mock_user,
|
||||
)
|
||||
|
||||
# Mock the created StoreListingVersion (returned by create)
|
||||
mock_store_listing_obj = prisma.models.StoreListing(
|
||||
mock_listing = prisma.models.StoreListing(
|
||||
id="listing-id",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isDeleted=False,
|
||||
hasApprovedVersion=False,
|
||||
slug="test-agent",
|
||||
agentGraphId="agent-id",
|
||||
owningUserId="user-id",
|
||||
useForOnboarding=False,
|
||||
)
|
||||
mock_version = prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
subHeading="",
|
||||
imageUrls=[],
|
||||
categories=[],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
submittedAt=now,
|
||||
StoreListing=mock_store_listing_obj,
|
||||
owningUserId="user-id",
|
||||
Versions=[
|
||||
prisma.models.StoreListingVersion(
|
||||
id="version-id",
|
||||
agentGraphId="agent-id",
|
||||
agentGraphVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
subHeading="Test heading",
|
||||
imageUrls=["image.jpg"],
|
||||
categories=["test"],
|
||||
isFeatured=False,
|
||||
isDeleted=False,
|
||||
version=1,
|
||||
storeListingId="listing-id",
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isAvailable=True,
|
||||
)
|
||||
],
|
||||
useForOnboarding=False,
|
||||
)
|
||||
|
||||
# Mock prisma calls
|
||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock transaction context manager
|
||||
mock_tx = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.store.db.transaction",
|
||||
return_value=mocker.AsyncMock(
|
||||
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||
__aexit__=mocker.AsyncMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
|
||||
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
||||
|
||||
# Call function
|
||||
result = await db.create_store_submission(
|
||||
user_id="user-id",
|
||||
graph_id="agent-id",
|
||||
graph_version=1,
|
||||
agent_id="agent-id",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
@@ -266,11 +281,11 @@ async def test_create_store_submission(mocker):
|
||||
# Verify results
|
||||
assert result.name == "Test Agent"
|
||||
assert result.description == "Test description"
|
||||
assert result.listing_version_id == "version-id"
|
||||
assert result.store_listing_version_id == "version-id"
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||
mock_slv.return_value.create.assert_called_once()
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -303,6 +318,7 @@ async def test_update_profile(mocker):
|
||||
description="Test description",
|
||||
links=["link1"],
|
||||
avatar_url="avatar.jpg",
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
# Call function
|
||||
@@ -373,7 +389,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||
category="AI'; DELETE FROM StoreAgent; --",
|
||||
featured=True,
|
||||
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||
sorted_by="rating",
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
@@ -57,6 +57,12 @@ class StoreError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AgentNotFoundError(NotFoundError):
|
||||
"""Raised when an agent is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreatorNotFoundError(NotFoundError):
|
||||
"""Raised when a creator is not found"""
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId" as "storeListingVersionId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
||||
SELECT uce."contentId", uce.embedding
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON uce."contentId" = sa.listing_version_id
|
||||
ON uce."contentId" = sa."storeListingVersionId"
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa.graph_id,
|
||||
sa."agentGraphId",
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
||||
sa.runs as popularity_raw
|
||||
FROM candidates c
|
||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||
ON c."storeListingVersionId" = sa.listing_version_id
|
||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
ON sa.listing_version_id = uce."contentId"
|
||||
ON sa."storeListingVersionId" = uce."contentId"
|
||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
),
|
||||
max_vals AS (
|
||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
graph_id,
|
||||
"agentGraphId",
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, List, Self
|
||||
from typing import List
|
||||
|
||||
import prisma.enums
|
||||
import pydantic
|
||||
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import prisma.models
|
||||
|
||||
|
||||
class ChangelogEntry(pydantic.BaseModel):
|
||||
version: str
|
||||
@@ -16,9 +13,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
||||
date: datetime.datetime
|
||||
|
||||
|
||||
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
class MyAgent(pydantic.BaseModel):
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
agent_name: str
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
@@ -26,8 +23,8 @@ class MyUnpublishedAgent(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyUnpublishedAgent]
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[MyAgent]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -43,21 +40,6 @@ class StoreAgent(pydantic.BaseModel):
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||
return cls(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.graph_id,
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
agents: list[StoreAgent]
|
||||
@@ -80,192 +62,81 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
graph_id: str
|
||||
graph_versions: list[str]
|
||||
agentGraphVersions: list[str]
|
||||
agentGraphId: str
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str
|
||||
has_approved_version: bool
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
|
||||
# Optional changelog data when include_changelog=True
|
||||
changelog: list[ChangelogEntry] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||
return cls(
|
||||
store_listing_version_id=agent.listing_version_id,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_output_demo=agent.agent_output_demo or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username or "",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
graph_id=agent.graph_id,
|
||||
graph_versions=agent.graph_versions,
|
||||
last_updated=agent.updated_at,
|
||||
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||
active_version_id=agent.listing_version_id,
|
||||
has_approved_version=True, # StoreAgent view only has approved agents
|
||||
)
|
||||
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
"""Marketplace user profile (only attributes that the user can update)"""
|
||||
|
||||
username: str
|
||||
class Creator(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str | None
|
||||
links: list[str]
|
||||
|
||||
|
||||
class ProfileDetails(Profile):
|
||||
"""Marketplace user profile (including read-only fields)"""
|
||||
|
||||
is_featured: bool
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||
return cls(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
avatar_url=profile.avatarUrl,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
is_featured=profile.isFeatured,
|
||||
)
|
||||
|
||||
|
||||
class CreatorDetails(ProfileDetails):
|
||||
"""Marketplace creator profile details, including aggregated stats"""
|
||||
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_runs: int
|
||||
agent_rating: float
|
||||
top_categories: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||
return cls(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
avatar_url=creator.avatar_url,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
is_featured=creator.is_featured,
|
||||
num_agents=creator.num_agents,
|
||||
agent_runs=creator.agent_runs,
|
||||
agent_rating=creator.agent_rating,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
agent_runs: int
|
||||
is_featured: bool
|
||||
|
||||
|
||||
class CreatorsResponse(pydantic.BaseModel):
|
||||
creators: List[CreatorDetails]
|
||||
creators: List[Creator]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
# From StoreListing:
|
||||
listing_id: str
|
||||
user_id: str
|
||||
slug: str
|
||||
class CreatorDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str]
|
||||
|
||||
# From StoreListingVersion:
|
||||
listing_version_id: str
|
||||
listing_version: int
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
class Profile(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class StoreSubmission(pydantic.BaseModel):
|
||||
listing_id: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None
|
||||
categories: list[str]
|
||||
instructions: str | None = None
|
||||
image_urls: list[str]
|
||||
video_url: str | None
|
||||
agent_output_demo_url: str | None
|
||||
|
||||
submitted_at: datetime.datetime | None
|
||||
changes_summary: str | None
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
version: int | None = None # Actual version number from the database
|
||||
|
||||
reviewer_id: str | None = None
|
||||
review_comments: str | None = None # External comments visible to creator
|
||||
internal_comments: str | None = None # Private notes for admin use only
|
||||
reviewed_at: datetime.datetime | None = None
|
||||
changes_summary: str | None = None
|
||||
|
||||
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||
run_count: int = 0
|
||||
review_count: int = 0
|
||||
review_avg_rating: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
"""Construct from the StoreSubmission Prisma view."""
|
||||
return cls(
|
||||
listing_id=_sub.listing_id,
|
||||
user_id=_sub.user_id,
|
||||
slug=_sub.slug,
|
||||
listing_version_id=_sub.listing_version_id,
|
||||
listing_version=_sub.listing_version,
|
||||
graph_id=_sub.graph_id,
|
||||
graph_version=_sub.graph_version,
|
||||
name=_sub.name,
|
||||
sub_heading=_sub.sub_heading,
|
||||
description=_sub.description,
|
||||
instructions=_sub.instructions,
|
||||
categories=_sub.categories,
|
||||
image_urls=_sub.image_urls,
|
||||
video_url=_sub.video_url,
|
||||
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||
submitted_at=_sub.submitted_at,
|
||||
changes_summary=_sub.changes_summary,
|
||||
status=_sub.status,
|
||||
reviewed_at=_sub.reviewed_at,
|
||||
reviewer_id=_sub.reviewer_id,
|
||||
review_comments=_sub.review_comments,
|
||||
run_count=_sub.run_count,
|
||||
review_count=_sub.review_count,
|
||||
review_avg_rating=_sub.review_avg_rating,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
"""
|
||||
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||
"""
|
||||
if not (_l := _lv.StoreListing):
|
||||
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||
|
||||
return cls(
|
||||
listing_id=_l.id,
|
||||
user_id=_l.owningUserId,
|
||||
slug=_l.slug,
|
||||
listing_version_id=_lv.id,
|
||||
listing_version=_lv.version,
|
||||
graph_id=_lv.agentGraphId,
|
||||
graph_version=_lv.agentGraphVersion,
|
||||
name=_lv.name,
|
||||
sub_heading=_lv.subHeading,
|
||||
description=_lv.description,
|
||||
instructions=_lv.instructions,
|
||||
categories=_lv.categories,
|
||||
image_urls=_lv.imageUrls,
|
||||
video_url=_lv.videoUrl,
|
||||
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||
submitted_at=_lv.submittedAt,
|
||||
changes_summary=_lv.changesSummary,
|
||||
status=_lv.submissionStatus,
|
||||
reviewed_at=_lv.reviewedAt,
|
||||
reviewer_id=_lv.reviewerId,
|
||||
review_comments=_lv.reviewComments,
|
||||
)
|
||||
# Additional fields for editing
|
||||
video_url: str | None = None
|
||||
agent_output_demo_url: str | None = None
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
@@ -273,12 +144,33 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreListingWithVersions(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
slug: str
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmission | None = None
|
||||
versions: list[StoreSubmission] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersions]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
graph_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Graph ID cannot be empty"
|
||||
agent_id: str = pydantic.Field(
|
||||
..., min_length=1, description="Agent ID cannot be empty"
|
||||
)
|
||||
graph_version: int = pydantic.Field(
|
||||
..., gt=0, description="Graph version must be greater than 0"
|
||||
agent_version: int = pydantic.Field(
|
||||
..., gt=0, description="Agent version must be greater than 0"
|
||||
)
|
||||
slug: str
|
||||
name: str
|
||||
@@ -306,42 +198,12 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class StoreSubmissionAdminView(StoreSubmission):
|
||||
internal_comments: str | None # Private admin notes
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_db(_sub).model_dump(),
|
||||
internal_comments=_sub.internal_comments,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||
return cls(
|
||||
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||
internal_comments=_lv.internalComments,
|
||||
)
|
||||
|
||||
|
||||
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||
"""A store listing with its version history"""
|
||||
|
||||
listing_id: str
|
||||
graph_id: str
|
||||
slug: str
|
||||
active_listing_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
creator_email: str | None = None
|
||||
latest_version: StoreSubmissionAdminView | None = None
|
||||
versions: list[StoreSubmissionAdminView] = []
|
||||
|
||||
|
||||
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||
"""Response model for listings with version history"""
|
||||
|
||||
listings: list[StoreListingWithVersionsAdminView]
|
||||
pagination: Pagination
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
links: list[str]
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class StoreReview(pydantic.BaseModel):
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
import datetime
|
||||
|
||||
import prisma.enums
|
||||
|
||||
from . import model as store_model
|
||||
|
||||
|
||||
def test_pagination():
|
||||
pagination = store_model.Pagination(
|
||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
||||
)
|
||||
assert pagination.total_items == 100
|
||||
assert pagination.total_pages == 5
|
||||
assert pagination.current_page == 2
|
||||
assert pagination.page_size == 20
|
||||
|
||||
|
||||
def test_store_agent():
|
||||
agent = store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
response = store_model.StoreAgentsResponse(
|
||||
agents=[
|
||||
store_model.StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="test.jpg",
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.agents) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_agent_details():
|
||||
details = store_model.StoreAgentDetails(
|
||||
store_listing_version_id="version123",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="video.mp4",
|
||||
agent_output_demo="demo.mp4",
|
||||
agent_image=["image1.jpg", "image2.jpg"],
|
||||
creator="creator1",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
versions=["1.0", "2.0"],
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=datetime.datetime.now(),
|
||||
)
|
||||
assert details.slug == "test-agent"
|
||||
assert len(details.agent_image) == 2
|
||||
assert len(details.categories) == 2
|
||||
assert len(details.versions) == 2
|
||||
|
||||
|
||||
def test_creator():
|
||||
creator = store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
assert creator.name == "Test Creator"
|
||||
assert creator.num_agents == 5
|
||||
|
||||
|
||||
def test_creators_response():
|
||||
response = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.Creator(
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
avatar_url="avatar.jpg",
|
||||
num_agents=5,
|
||||
is_featured=False,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.creators) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_creator_details():
|
||||
details = store_model.CreatorDetails(
|
||||
name="Test Creator",
|
||||
username="creator1",
|
||||
description="Test description",
|
||||
links=["link1.com", "link2.com"],
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["cat1", "cat2"],
|
||||
)
|
||||
assert details.name == "Test Creator"
|
||||
assert len(details.links) == 2
|
||||
assert details.agent_rating == 4.8
|
||||
assert len(details.top_categories) == 2
|
||||
|
||||
|
||||
def test_store_submission():
|
||||
submission = store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
assert submission.name == "Test Agent"
|
||||
assert len(submission.image_urls) == 2
|
||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
||||
|
||||
|
||||
def test_store_submissions_response():
|
||||
response = store_model.StoreSubmissionsResponse(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="listing123",
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
sub_heading="Test subheading",
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
description="Test description",
|
||||
image_urls=["image1.jpg"],
|
||||
date_submitted=datetime.datetime(2023, 1, 1),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
||||
),
|
||||
)
|
||||
assert len(response.submissions) == 1
|
||||
assert response.pagination.total_items == 1
|
||||
|
||||
|
||||
def test_store_submission_request():
|
||||
request = store_model.StoreSubmissionRequest(
|
||||
agent_id="agent123",
|
||||
agent_version=1,
|
||||
slug="test-agent",
|
||||
name="Test Agent",
|
||||
sub_heading="Test subheading",
|
||||
video_url="video.mp4",
|
||||
image_urls=["image1.jpg", "image2.jpg"],
|
||||
description="Test description",
|
||||
categories=["cat1", "cat2"],
|
||||
)
|
||||
assert request.agent_id == "agent123"
|
||||
assert request.agent_version == 1
|
||||
assert len(request.image_urls) == 2
|
||||
assert len(request.categories) == 2
|
||||
@@ -1,17 +1,16 @@
|
||||
import logging
|
||||
import tempfile
|
||||
import typing
|
||||
import urllib.parse
|
||||
from typing import Literal
|
||||
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma.enums
|
||||
from fastapi import Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.data.graph
|
||||
import backend.util.json
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
@@ -35,15 +34,22 @@ router = fastapi.APIRouter()
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.ProfileDetails,
|
||||
)
|
||||
async def get_profile(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Get the profile details for the authenticated user."""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
profile = await store_db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
raise NotFoundError("User does not have a profile yet")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
return profile
|
||||
|
||||
|
||||
@@ -51,17 +57,98 @@ async def get_profile(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def update_or_create_profile(
|
||||
profile: store_model.Profile,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.ProfileDetails:
|
||||
"""Update the store profile for the authenticated user."""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Update the store profile for the authenticated user.
|
||||
|
||||
Args:
|
||||
profile (Profile): The updated profile details
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
CreatorDetails: The updated profile
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentsResponse,
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = False,
|
||||
creator: str | None = None,
|
||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
||||
search_query: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
##############################################
|
||||
############### Search Endpoints #############
|
||||
##############################################
|
||||
@@ -71,30 +158,60 @@ async def update_or_create_profile(
|
||||
"/search",
|
||||
summary="Unified search across all content types",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.UnifiedSearchResponse,
|
||||
)
|
||||
async def unified_search(
|
||||
query: str,
|
||||
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||
content_types: list[str] | None = fastapi.Query(
|
||||
default=None,
|
||||
description="Content types to search. If not specified, searches all.",
|
||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
user_id: str | None = Security(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user_id: str | None = fastapi.Security(
|
||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||
),
|
||||
) -> store_model.UnifiedSearchResponse:
|
||||
):
|
||||
"""
|
||||
Search across all content types (marketplace agents, blocks, documentation)
|
||||
using hybrid search.
|
||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
||||
|
||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||
page: Page number for pagination (default 1)
|
||||
page_size: Number of results per page (default 20)
|
||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||
|
||||
Returns:
|
||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
# Convert string content types to enum
|
||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||
if content_types:
|
||||
try:
|
||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||
)
|
||||
|
||||
# Perform unified hybrid search
|
||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=content_types,
|
||||
content_types=content_type_enums,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@@ -128,69 +245,22 @@ async def unified_search(
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Agent Endpoints ##############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def get_agents(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured agents"
|
||||
),
|
||||
creator: str | None = Query(
|
||||
default=None, description="Filter agents by creator username"
|
||||
),
|
||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||
default=None,
|
||||
description="Property to sort results by. Ignored if search_query is provided.",
|
||||
),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get a paginated list of agents from the marketplace,
|
||||
with optional filtering and sorting.
|
||||
|
||||
Used for:
|
||||
- Home Page Featured Agents
|
||||
- Home Page Top Agents
|
||||
- Search Results
|
||||
- Agent Details - Other Agents By Creator
|
||||
- Agent Details - Similar Agents
|
||||
- Creator Details - Agents By Creator
|
||||
"""
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_agent_by_name(
|
||||
async def get_agent(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
include_changelog: bool = Query(default=False),
|
||||
) -> store_model.StoreAgentDetails:
|
||||
"""Get details of a marketplace agent"""
|
||||
include_changelog: bool = fastapi.Query(default=False),
|
||||
):
|
||||
"""
|
||||
This is only used on the AgentDetails Page.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
@@ -200,82 +270,76 @@ async def get_agent_by_name(
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreAgentDetails,
|
||||
)
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreReview,
|
||||
)
|
||||
async def post_user_review_for_agent(
|
||||
async def create_review(
|
||||
username: str,
|
||||
agent_name: str,
|
||||
review: store_model.StoreReviewCreate,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreReview:
|
||||
"""Post a user review on a marketplace agent listing"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a review for a store agent.
|
||||
|
||||
Args:
|
||||
username: Creator's username
|
||||
agent_name: Name/slug of the agent
|
||||
review: Review details including score and optional comments
|
||||
user_id: ID of authenticated user creating the review
|
||||
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
|
||||
# Create the review
|
||||
created_review = await store_db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_agent_by_listing_version(
|
||||
store_listing_version_id: str,
|
||||
) -> store_model.StoreAgentDetails:
|
||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str,
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""Download agent graph file for a specific marketplace listing version"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############# Creator Endpoints #############
|
||||
##############################################
|
||||
@@ -285,19 +349,37 @@ async def download_agent_file(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorsResponse,
|
||||
)
|
||||
async def get_creators(
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to only show featured creators"
|
||||
),
|
||||
search_query: str | None = Query(
|
||||
default=None, description="Literal + semantic search on names and descriptions"
|
||||
),
|
||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.CreatorsResponse:
|
||||
"""List or search marketplace creators"""
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
This is needed for:
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
- featured: bool - to limit the list to just featured agents
|
||||
- search_query: str - vector search based on the creators profile description.
|
||||
- sorted_by: [agent_rating, agent_runs] -
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
@@ -309,12 +391,18 @@ async def get_creators(
|
||||
|
||||
|
||||
@router.get(
|
||||
"/creators/{username}",
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=store_model.CreatorDetails,
|
||||
)
|
||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
"""Get details on a marketplace creator"""
|
||||
async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
@@ -326,17 +414,20 @@ async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||
|
||||
|
||||
@router.get(
|
||||
"/my-unpublished-agents",
|
||||
"/myagents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.MyAgentsResponse,
|
||||
)
|
||||
async def get_my_unpublished_agents(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.MyUnpublishedAgentsResponse:
|
||||
"""List the authenticated user's unpublished agents"""
|
||||
async def get_my_agents(
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
|
||||
@@ -345,17 +436,28 @@ async def get_my_unpublished_agents(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=bool,
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> bool:
|
||||
"""Delete a marketplace listing submission"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
submission_id (str): ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
result = await store_db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -363,14 +465,37 @@ async def delete_submission(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmissionsResponse,
|
||||
)
|
||||
async def get_submissions(
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = Query(ge=1, default=1),
|
||||
page_size: int = Query(ge=1, default=20),
|
||||
) -> store_model.StoreSubmissionsResponse:
|
||||
"""List the authenticated user's marketplace listing submissions"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
page (int, optional): Page number for pagination. Defaults to 1.
|
||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
StoreListingsResponse: Paginated list of store submissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If page or page_size are less than 1
|
||||
"""
|
||||
if page < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page must be greater than 0"
|
||||
)
|
||||
|
||||
if page_size < 1:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
listings = await store_db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
@@ -383,17 +508,30 @@ async def get_submissions(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def create_submission(
|
||||
submission_request: store_model.StoreSubmissionRequest,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Submit a new marketplace listing for review"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Create a new store listing submission.
|
||||
|
||||
Args:
|
||||
submission_request (StoreSubmissionRequest): The submission details
|
||||
user_id (str): ID of the authenticated user submitting the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The created store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
result = await store_db.create_store_submission(
|
||||
user_id=user_id,
|
||||
graph_id=submission_request.graph_id,
|
||||
graph_version=submission_request.graph_version,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
@@ -406,6 +544,7 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -413,14 +552,28 @@ async def create_submission(
|
||||
"/submissions/{store_listing_version_id}",
|
||||
summary="Edit store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
response_model=store_model.StoreSubmission,
|
||||
)
|
||||
async def edit_submission(
|
||||
store_listing_version_id: str,
|
||||
submission_request: store_model.StoreSubmissionEditRequest,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> store_model.StoreSubmission:
|
||||
"""Update a pending marketplace listing submission"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to edit
|
||||
submission_request (StoreSubmissionRequest): The updated submission details
|
||||
user_id (str): ID of the authenticated user editing the listing
|
||||
|
||||
Returns:
|
||||
StoreSubmission: The updated store submission
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
result = await store_db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
@@ -435,6 +588,7 @@ async def edit_submission(
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -442,61 +596,115 @@ async def edit_submission(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def upload_submission_media(
|
||||
file: fastapi.UploadFile,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> str:
|
||||
"""Upload media for a marketplace listing submission"""
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Upload media (images/videos) for a store listing submission.
|
||||
|
||||
Args:
|
||||
file (UploadFile): The media file to upload
|
||||
user_id (str): ID of the authenticated user uploading the media
|
||||
|
||||
Returns:
|
||||
str: URL of the uploaded media file
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||
return media_url
|
||||
|
||||
|
||||
class ImageURLResponse(BaseModel):
|
||||
image_url: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||
)
|
||||
async def generate_image(
|
||||
graph_id: str,
|
||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||
) -> ImageURLResponse:
|
||||
agent_id: str,
|
||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||
) -> fastapi.responses.Response:
|
||||
"""
|
||||
Generate an image for a marketplace listing submission based on the properties
|
||||
of a given graph.
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
"""
|
||||
graph = await backend.data.graph.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=user_id
|
||||
agent = await backend.data.graph.get_graph(
|
||||
graph_id=agent_id, version=None, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
|
||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||
return ImageURLResponse(image_url=existing_url)
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return ImageURLResponse(image_url=image_url)
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
store_listing_version_id: str = fastapi.Path(
|
||||
..., description="The ID of the agent to download"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): The ID of the agent to download
|
||||
|
||||
Returns:
|
||||
StreamingResponse: A streaming response containing the agent's graph data.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||
|
||||
# Sending graph as a stream (similar to marketplace v1)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
|
||||
@@ -8,8 +8,6 @@ import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
|
||||
from . import model as store_model
|
||||
from . import routes as store_routes
|
||||
|
||||
@@ -198,7 +196,7 @@ def test_get_agents_sorted(
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creators=None,
|
||||
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||
sorted_by="runs",
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
@@ -382,11 +380,9 @@ def test_get_agent_details(
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=["1.0.0", "1.1.0"],
|
||||
graph_versions=["1", "2"],
|
||||
graph_id="test-graph-id",
|
||||
agentGraphVersions=["1", "2"],
|
||||
agentGraphId="test-graph-id",
|
||||
last_updated=FIXED_NOW,
|
||||
active_version_id="test-version-id",
|
||||
has_approved_version=True,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||
mock_db_call.return_value = mocked_value
|
||||
@@ -439,17 +435,15 @@ def test_get_creators_pagination(
|
||||
) -> None:
|
||||
mocked_value = store_model.CreatorsResponse(
|
||||
creators=[
|
||||
store_model.CreatorDetails(
|
||||
store_model.Creator(
|
||||
name=f"Creator {i}",
|
||||
username=f"creator{i}",
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
description=f"Creator {i} description",
|
||||
links=[f"user{i}.link.com"],
|
||||
is_featured=False,
|
||||
avatar_url=f"avatar{i}.jpg",
|
||||
num_agents=1,
|
||||
agent_runs=100,
|
||||
agent_rating=4.5,
|
||||
top_categories=["cat1", "cat2", "cat3"],
|
||||
agent_runs=100,
|
||||
is_featured=False,
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
@@ -502,19 +496,19 @@ def test_get_creator_details(
|
||||
mocked_value = store_model.CreatorDetails(
|
||||
name="Test User",
|
||||
username="creator1",
|
||||
avatar_url="avatar.jpg",
|
||||
description="Test creator description",
|
||||
links=["link1.com", "link2.com"],
|
||||
is_featured=True,
|
||||
num_agents=5,
|
||||
agent_runs=1000,
|
||||
avatar_url="avatar.jpg",
|
||||
agent_rating=4.8,
|
||||
agent_runs=1000,
|
||||
top_categories=["category1", "category2"],
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.api.features.store.db.get_store_creator_details"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/creators/creator1")
|
||||
response = client.get("/creator/creator1")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = store_model.CreatorDetails.model_validate(response.json())
|
||||
@@ -534,26 +528,19 @@ def test_get_submissions_success(
|
||||
submissions=[
|
||||
store_model.StoreSubmission(
|
||||
listing_id="test-listing-id",
|
||||
user_id="test-user-id",
|
||||
slug="test-agent",
|
||||
listing_version_id="test-version-id",
|
||||
listing_version=1,
|
||||
graph_id="test-agent-id",
|
||||
graph_version=1,
|
||||
name="Test Agent",
|
||||
sub_heading="Test agent subheading",
|
||||
description="Test agent description",
|
||||
instructions="Click the button!",
|
||||
categories=["test-category"],
|
||||
image_urls=["test.jpg"],
|
||||
video_url="test.mp4",
|
||||
agent_output_demo_url="demo_video.mp4",
|
||||
submitted_at=FIXED_NOW,
|
||||
changes_summary="Initial Submission",
|
||||
date_submitted=FIXED_NOW,
|
||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||
run_count=50,
|
||||
review_count=5,
|
||||
review_avg_rating=4.2,
|
||||
runs=50,
|
||||
rating=4.2,
|
||||
agent_id="test-agent-id",
|
||||
agent_version=1,
|
||||
sub_heading="Test agent subheading",
|
||||
slug="test-agent",
|
||||
video_url="test.mp4",
|
||||
categories=["test-category"],
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -11,7 +11,6 @@ import pytest
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from . import cache as store_cache
|
||||
from .db import StoreAgentsSortOptions
|
||||
from .model import StoreAgent, StoreAgentsResponse
|
||||
|
||||
|
||||
@@ -216,7 +215,7 @@ class TestCacheDeletion:
|
||||
await store_cache._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -228,7 +227,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
@@ -240,7 +239,7 @@ class TestCacheDeletion:
|
||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by=StoreAgentsSortOptions.RATING,
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
|
||||
@@ -126,9 +126,6 @@ v1_router = APIRouter()
|
||||
########################################################
|
||||
|
||||
|
||||
_tally_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
@@ -137,24 +134,6 @@ _tally_background_tasks: set[asyncio.Task] = set()
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
from backend.data.tally import populate_understanding_from_tally
|
||||
|
||||
task = asyncio.create_task(
|
||||
populate_understanding_from_tally(user.id, user.email)
|
||||
)
|
||||
_tally_background_tasks.add(task)
|
||||
task.add_done_callback(_tally_background_tasks.discard)
|
||||
except Exception:
|
||||
logger.debug("Failed to start Tally population task", exc_info=True)
|
||||
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@@ -449,6 +428,7 @@ async def execute_graph_block(
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
file: UploadFile = File(...),
|
||||
provider: str = "gcs",
|
||||
expiration_hours: int = 24,
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
@@ -511,6 +491,7 @@ async def upload_file(
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -43,7 +43,6 @@ def test_get_or_create_user_route(
|
||||
) -> None:
|
||||
"""Test get or create user endpoint"""
|
||||
mock_user = Mock()
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
mock_user.model_dump.return_value = {
|
||||
"id": test_user_id,
|
||||
"email": "test@example.com",
|
||||
@@ -515,6 +514,7 @@ async def test_upload_file_success(test_user_id: str):
|
||||
result = await upload_file(
|
||||
file=upload_file_mock,
|
||||
user_id=test_user_id,
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
)
|
||||
|
||||
@@ -532,6 +532,7 @@ async def test_upload_file_success(test_user_id: str):
|
||||
mock_handler.store_file.assert_called_once_with(
|
||||
content=file_content,
|
||||
filename="test.txt",
|
||||
provider="gcs",
|
||||
expiration_hours=24,
|
||||
user_id=test_user_id,
|
||||
)
|
||||
|
||||
@@ -3,29 +3,15 @@ Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
get_or_create_workspace,
|
||||
get_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_total_size,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -112,25 +98,6 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
raise
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class DeleteFileResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
used_percent: float
|
||||
file_count: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
@@ -153,148 +120,3 @@ async def download_file(
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/files/{file_id}",
|
||||
summary="Delete a workspace file",
|
||||
)
|
||||
async def delete_workspace_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> DeleteFileResponse:
|
||||
"""
|
||||
Soft-delete a workspace file and attempt to remove it from storage.
|
||||
|
||||
Used when a user clears a file input in the builder.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
manager = WorkspaceManager(user_id, workspace.id)
|
||||
deleted = await manager.delete_file(file_id)
|
||||
if not deleted:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return DeleteFileResponse(deleted=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||
|
||||
# Read file content with early abort on size limit
|
||||
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||
total_size += len(chunk)
|
||||
if total_size > max_file_bytes:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
try:
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
except ValueError as e:
|
||||
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_id=workspace_file.id,
|
||||
name=workspace_file.name,
|
||||
path=workspace_file.path,
|
||||
mime_type=workspace_file.mime_type,
|
||||
size_bytes=workspace_file.size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> StorageUsageResponse:
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
limit_bytes=limit_bytes,
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
@@ -1,359 +0,0 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["file_id"] == "file-aaa-bbb"
|
||||
assert data["name"] == "hello.txt"
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(
|
||||
filename="Makefile",
|
||||
content=b"all:\n\techo hello",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_manager.write_file.assert_called_once()
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---- Delete ----
|
||||
|
||||
|
||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting an existing file should return {"deleted": true}."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"deleted": True}
|
||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||
|
||||
|
||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting a non-existent file should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = client.delete("/files/nonexistent-id")
|
||||
assert response.status_code == 404
|
||||
assert "File not found" in response.text
|
||||
|
||||
|
||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""Deleting when user has no workspace should return 404."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.delete("/files/file-aaa-bbb")
|
||||
assert response.status_code == 404
|
||||
assert "Workspace not found" in response.text
|
||||
@@ -41,11 +41,11 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.library.exceptions import (
|
||||
FolderAlreadyExistsError,
|
||||
FolderValidationError,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -55,7 +55,6 @@ from backend.util.exceptions import (
|
||||
MissingConfigError,
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
PreconditionFailed,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
@@ -124,9 +123,21 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
@@ -266,17 +277,12 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
|
||||
@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
|
||||
# Run the last process in the foreground.
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in reversed(processes):
|
||||
for process in processes:
|
||||
try:
|
||||
process.stop()
|
||||
except Exception as e:
|
||||
|
||||
@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
|
||||
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
|
||||
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
|
||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code( # type: ignore[attr-defined]
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
|
||||
@@ -31,7 +31,6 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
@@ -117,7 +116,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -276,9 +274,6 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-6
|
||||
LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
|
||||
), # claude-sonnet-4-6
|
||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||
), # claude-opus-4-5-20251101
|
||||
@@ -826,7 +821,7 @@ async def llm_call(
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -19,11 +20,6 @@ from backend.blocks._base import (
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.helpers import (
|
||||
auto_lookup_mcp_credential,
|
||||
normalize_mcp_url,
|
||||
parse_mcp_content,
|
||||
)
|
||||
from backend.data.block import BlockInput, BlockOutput
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -183,7 +179,31 @@ class MCPToolBlock(Block):
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
return parse_mcp_content(result.content)
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
async def _auto_lookup_credential(
|
||||
@@ -191,10 +211,37 @@ class MCPToolBlock(Block):
|
||||
) -> "OAuth2Credentials | None":
|
||||
"""Auto-lookup stored MCP credential for a server URL.
|
||||
|
||||
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
|
||||
The caller should pass a normalized URL.
|
||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||
set (e.g. nodes created before the credential field was wired up).
|
||||
"""
|
||||
return await auto_lookup_mcp_credential(user_id, server_url)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info(
|
||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||
)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -231,7 +278,7 @@ class MCPToolBlock(Block):
|
||||
# the stored MCP credential for this server URL.
|
||||
if credentials is None:
|
||||
credentials = await self._auto_lookup_credential(
|
||||
user_id, normalize_mcp_url(input_data.server_url)
|
||||
user_id, input_data.server_url
|
||||
)
|
||||
|
||||
auth_token = (
|
||||
|
||||
@@ -55,9 +55,7 @@ class MCPClient:
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url
|
||||
|
||||
self.server_url = normalize_mcp_url(server_url)
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.auth_token = auth_token
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_mcp_url(url: str) -> str:
|
||||
"""Normalize an MCP server URL for consistent credential matching.
|
||||
|
||||
Strips leading/trailing whitespace and a single trailing slash so that
|
||||
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
|
||||
the same stored credential.
|
||||
"""
|
||||
return url.strip().rstrip("/")
|
||||
|
||||
|
||||
def server_host(server_url: str) -> str:
|
||||
"""Extract the hostname from a server URL for display purposes.
|
||||
|
||||
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
|
||||
username/password before surfacing the value in UI messages.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(server_url)
|
||||
return parsed.hostname or server_url
|
||||
except Exception:
|
||||
return server_url
|
||||
|
||||
|
||||
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
|
||||
"""Parse MCP tool response content into a plain Python value.
|
||||
|
||||
- text items: parsed as JSON when possible, kept as str otherwise
|
||||
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
|
||||
- resource items: unwrapped to their resource payload dict
|
||||
|
||||
Single-item responses are unwrapped from the list; multiple items are
|
||||
returned as a list; empty content returns ``None``.
|
||||
"""
|
||||
output_parts: list[Any] = []
|
||||
for item in content:
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text = item.get("text", "")
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item_type == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item_type == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts or None
|
||||
|
||||
|
||||
async def auto_lookup_mcp_credential(
|
||||
user_id: str, server_url: str
|
||||
) -> OAuth2Credentials | None:
|
||||
"""Look up the best stored MCP credential for *server_url*.
|
||||
|
||||
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
|
||||
so the comparison with ``mcp_server_url`` in credential metadata matches.
|
||||
|
||||
Returns the credential with the latest ``access_token_expires_at``, refreshed
|
||||
if needed, or ``None`` when no match is found.
|
||||
"""
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Collect all matching credentials and pick the best one.
|
||||
# Primary sort: latest access_token_expires_at (tokens with expiry
|
||||
# are preferred over non-expiring ones). Secondary sort: last in
|
||||
# iteration order, which corresponds to the most recently created
|
||||
# row — this acts as a tiebreaker when multiple bearer tokens have
|
||||
# no expiry (e.g. after a failed old-credential cleanup).
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
>= (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Unit tests for the shared MCP helpers."""
|
||||
|
||||
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# normalize_mcp_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_trailing_slash():
|
||||
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_whitespace():
|
||||
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_both():
|
||||
assert (
|
||||
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_noop():
|
||||
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
|
||||
|
||||
|
||||
def test_normalize_path_with_trailing_slash():
|
||||
assert (
|
||||
normalize_mcp_url("https://mcp.example.com/path/")
|
||||
== "https://mcp.example.com/path"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# server_host
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_server_host_standard_url():
|
||||
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_strips_credentials():
|
||||
"""hostname must not expose user:pass."""
|
||||
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_with_port():
|
||||
"""Port should not appear in hostname (hostname strips it)."""
|
||||
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
|
||||
|
||||
|
||||
def test_server_host_fallback():
|
||||
"""Falls back to the raw string for un-parseable URLs."""
|
||||
assert server_host("not-a-url") == "not-a-url"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_mcp_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_text_plain():
|
||||
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
|
||||
|
||||
|
||||
def test_parse_text_json():
|
||||
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
|
||||
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
|
||||
|
||||
|
||||
def test_parse_image():
|
||||
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
|
||||
assert parse_mcp_content(content) == {
|
||||
"type": "image",
|
||||
"data": "abc123==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
|
||||
def test_parse_resource():
|
||||
content = [
|
||||
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
|
||||
]
|
||||
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
|
||||
|
||||
|
||||
def test_parse_multi_item():
|
||||
content = [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
assert parse_mcp_content(content) == ["first", "second"]
|
||||
|
||||
|
||||
def test_parse_empty():
|
||||
assert parse_mcp_content([]) is None
|
||||
@@ -21,7 +21,6 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
|
||||
@@ -137,7 +136,7 @@ class PerplexityBlock(Block):
|
||||
) -> dict[str, Any]:
|
||||
"""Call Perplexity via OpenRouter and extract annotations."""
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
|
||||
@@ -83,8 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" # Keep for backwards compat
|
||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
@@ -138,7 +137,7 @@ class StagehandObserveBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -228,7 +227,7 @@ class StagehandActBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -325,7 +324,7 @@ class StagehandExtractBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
|
||||
category: str | None = SchemaField(
|
||||
description="Filter by category", default=None
|
||||
)
|
||||
sort_by: StoreAgentsSortOptions = SchemaField(
|
||||
description="How to sort the results", default=StoreAgentsSortOptions.RATING
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
|
||||
description="How to sort the results", default="rating"
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=10, ge=1, le=100
|
||||
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
||||
self,
|
||||
query: str | None = None,
|
||||
category: str | None = None,
|
||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
||||
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||
limit: int = 10,
|
||||
) -> SearchAgentsResponse:
|
||||
"""
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
"""
|
||||
Telegram Bot API helper functions.
|
||||
|
||||
Provides utilities for making authenticated requests to the Telegram Bot API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TELEGRAM_API_BASE = "https://api.telegram.org"
|
||||
|
||||
|
||||
class TelegramMessageResult(BaseModel, extra="allow"):
|
||||
"""Result from Telegram send/edit message API calls."""
|
||||
|
||||
message_id: int = 0
|
||||
chat: dict[str, Any] = {}
|
||||
date: int = 0
|
||||
text: str = ""
|
||||
|
||||
|
||||
class TelegramFileResult(BaseModel, extra="allow"):
|
||||
"""Result from Telegram getFile API call."""
|
||||
|
||||
file_id: str = ""
|
||||
file_unique_id: str = ""
|
||||
file_size: int = 0
|
||||
file_path: str = ""
|
||||
|
||||
|
||||
class TelegramAPIException(ValueError):
|
||||
"""Exception raised for Telegram API errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int = 0):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
def get_bot_api_url(bot_token: str, method: str) -> str:
|
||||
"""Construct Telegram Bot API URL for a method."""
|
||||
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
|
||||
|
||||
|
||||
def get_file_url(bot_token: str, file_path: str) -> str:
|
||||
"""Construct Telegram file download URL."""
|
||||
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
|
||||
|
||||
|
||||
async def call_telegram_api(
|
||||
credentials: APIKeyCredentials,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
) -> TelegramMessageResult:
|
||||
"""
|
||||
Make a request to the Telegram Bot API.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
method: API method name (e.g., "sendMessage", "getFile")
|
||||
data: Request parameters
|
||||
|
||||
Returns:
|
||||
API response result
|
||||
|
||||
Raises:
|
||||
TelegramAPIException: If the API returns an error
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
url = get_bot_api_url(token, method)
|
||||
|
||||
response = await Requests().post(url, json=data or {})
|
||||
result = response.json()
|
||||
|
||||
if not result.get("ok"):
|
||||
error_code = result.get("error_code", 0)
|
||||
description = result.get("description", "Unknown error")
|
||||
raise TelegramAPIException(description, error_code)
|
||||
|
||||
return TelegramMessageResult(**result.get("result", {}))
|
||||
|
||||
|
||||
async def call_telegram_api_with_file(
|
||||
credentials: APIKeyCredentials,
|
||||
method: str,
|
||||
file_field: str,
|
||||
file_data: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
) -> TelegramMessageResult:
|
||||
"""
|
||||
Make a multipart/form-data request to the Telegram Bot API with a file upload.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
method: API method name (e.g., "sendPhoto", "sendVoice")
|
||||
file_field: Form field name for the file (e.g., "photo", "voice")
|
||||
file_data: Raw file bytes
|
||||
filename: Filename for the upload
|
||||
content_type: MIME type of the file
|
||||
data: Additional form parameters
|
||||
|
||||
Returns:
|
||||
API response result
|
||||
|
||||
Raises:
|
||||
TelegramAPIException: If the API returns an error
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
url = get_bot_api_url(token, method)
|
||||
|
||||
files = [(file_field, (filename, BytesIO(file_data), content_type))]
|
||||
|
||||
response = await Requests().post(url, files=files, data=data or {})
|
||||
result = response.json()
|
||||
|
||||
if not result.get("ok"):
|
||||
error_code = result.get("error_code", 0)
|
||||
description = result.get("description", "Unknown error")
|
||||
raise TelegramAPIException(description, error_code)
|
||||
|
||||
return TelegramMessageResult(**result.get("result", {}))
|
||||
|
||||
|
||||
async def get_file_info(
|
||||
credentials: APIKeyCredentials, file_id: str
|
||||
) -> TelegramFileResult:
|
||||
"""
|
||||
Get file information from Telegram.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id from message
|
||||
|
||||
Returns:
|
||||
File info dict containing file_id, file_unique_id, file_size, file_path
|
||||
"""
|
||||
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
|
||||
return TelegramFileResult(**result.model_dump())
|
||||
|
||||
|
||||
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
|
||||
"""
|
||||
Get the download URL for a Telegram file.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id from message
|
||||
|
||||
Returns:
|
||||
Full download URL
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
result = await get_file_info(credentials, file_id)
|
||||
file_path = result.file_path
|
||||
if not file_path:
|
||||
raise TelegramAPIException("No file_path returned from getFile")
|
||||
return get_file_url(token, file_path)
|
||||
|
||||
|
||||
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
|
||||
"""
|
||||
Download a file from Telegram servers.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
url = await get_file_download_url(credentials, file_id)
|
||||
response = await Requests().get(url)
|
||||
return response.content
|
||||
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
Telegram Bot credentials handling.
|
||||
|
||||
Telegram bots use an API key (bot token) obtained from @BotFather.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Bot token credentials (API key style)
|
||||
TelegramCredentials = APIKeyCredentials
|
||||
TelegramCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.TELEGRAM], Literal["api_key"]
|
||||
]
|
||||
|
||||
|
||||
def TelegramCredentialsField() -> TelegramCredentialsInput:
|
||||
"""Creates a Telegram bot token credentials field."""
|
||||
return CredentialsField(
|
||||
description="Telegram Bot API token from @BotFather. "
|
||||
"Create a bot at https://t.me/BotFather to get your token."
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for unit tests
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="telegram",
|
||||
api_key=SecretStr("test_telegram_bot_token"),
|
||||
title="Mock Telegram Bot Token",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,377 +0,0 @@
|
||||
"""
|
||||
Telegram trigger blocks for receiving messages via webhooks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.telegram import TelegramWebhookType
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TelegramCredentialsField,
|
||||
TelegramCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Example payload for testing
|
||||
EXAMPLE_MESSAGE_PAYLOAD = {
|
||||
"update_id": 123456789,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"language_code": "en",
|
||||
},
|
||||
"chat": {
|
||||
"id": 12345678,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"type": "private",
|
||||
},
|
||||
"date": 1234567890,
|
||||
"text": "Hello, bot!",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TelegramTriggerBase:
|
||||
"""Base class for Telegram trigger blocks."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: TelegramCredentialsInput = TelegramCredentialsField()
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
|
||||
|
||||
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
|
||||
"""
|
||||
Triggers when a message is received or edited in your Telegram bot.
|
||||
|
||||
Supports text, photos, voice messages, audio files, documents, and videos.
|
||||
Connect the outputs to other blocks to process messages and send responses.
|
||||
"""
|
||||
|
||||
class Input(TelegramTriggerBase.Input):
|
||||
class EventsFilter(BaseModel):
|
||||
"""Filter for message types to receive."""
|
||||
|
||||
text: bool = True
|
||||
photo: bool = False
|
||||
voice: bool = False
|
||||
audio: bool = False
|
||||
document: bool = False
|
||||
video: bool = False
|
||||
edited_message: bool = False
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Message Types", description="Types of messages to receive"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload from Telegram"
|
||||
)
|
||||
chat_id: int = SchemaField(
|
||||
description="The chat ID where the message was received. "
|
||||
"Use this to send replies."
|
||||
)
|
||||
message_id: int = SchemaField(description="The unique message ID")
|
||||
user_id: int = SchemaField(description="The user ID who sent the message")
|
||||
username: str = SchemaField(description="Username of the sender (may be empty)")
|
||||
first_name: str = SchemaField(description="First name of the sender")
|
||||
event: str = SchemaField(
|
||||
description="The message type (text, photo, voice, audio, etc.)"
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Text content of the message (for text messages)"
|
||||
)
|
||||
photo_file_id: str = SchemaField(
|
||||
description="File ID of the photo (for photo messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
voice_file_id: str = SchemaField(
|
||||
description="File ID of the voice message (for voice messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
audio_file_id: str = SchemaField(
|
||||
description="File ID of the audio file (for audio messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
file_id: str = SchemaField(
|
||||
description="File ID for document/video messages. "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
file_name: str = SchemaField(
|
||||
description="Original filename (for document/audio messages)"
|
||||
)
|
||||
caption: str = SchemaField(description="Caption for media messages")
|
||||
is_edited: bool = SchemaField(
|
||||
description="Whether this is an edit of a previously sent message"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
|
||||
description="Triggers when a message is received or edited in your Telegram bot. "
|
||||
"Supports text, photos, voice messages, audio files, documents, and videos.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TelegramMessageTriggerBlock.Input,
|
||||
output_schema=TelegramMessageTriggerBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.TELEGRAM,
|
||||
webhook_type=TelegramWebhookType.BOT,
|
||||
resource_format="bot",
|
||||
event_filter_input="events",
|
||||
event_format="message.{event}",
|
||||
),
|
||||
test_input={
|
||||
"events": {"text": True, "photo": True},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"payload": EXAMPLE_MESSAGE_PAYLOAD,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("payload", EXAMPLE_MESSAGE_PAYLOAD),
|
||||
("chat_id", 12345678),
|
||||
("message_id", 1),
|
||||
("user_id", 12345678),
|
||||
("username", "johndoe"),
|
||||
("first_name", "John"),
|
||||
("is_edited", False),
|
||||
("event", "text"),
|
||||
("text", "Hello, bot!"),
|
||||
("photo_file_id", ""),
|
||||
("voice_file_id", ""),
|
||||
("audio_file_id", ""),
|
||||
("file_id", ""),
|
||||
("file_name", ""),
|
||||
("caption", ""),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
is_edited = "edited_message" in payload
|
||||
message = payload.get("message") or payload.get("edited_message", {})
|
||||
|
||||
# Extract common fields
|
||||
chat = message.get("chat", {})
|
||||
sender = message.get("from", {})
|
||||
|
||||
yield "payload", payload
|
||||
yield "chat_id", chat.get("id", 0)
|
||||
yield "message_id", message.get("message_id", 0)
|
||||
yield "user_id", sender.get("id", 0)
|
||||
yield "username", sender.get("username", "")
|
||||
yield "first_name", sender.get("first_name", "")
|
||||
yield "is_edited", is_edited
|
||||
|
||||
# For edited messages, yield event as "edited_message" and extract
|
||||
# all content fields from the edited message body
|
||||
if is_edited:
|
||||
yield "event", "edited_message"
|
||||
yield "text", message.get("text", "")
|
||||
photos = message.get("photo", [])
|
||||
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
|
||||
voice = message.get("voice", {})
|
||||
yield "voice_file_id", voice.get("file_id", "")
|
||||
audio = message.get("audio", {})
|
||||
yield "audio_file_id", audio.get("file_id", "")
|
||||
document = message.get("document", {})
|
||||
video = message.get("video", {})
|
||||
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
|
||||
yield "file_name", (
|
||||
document.get("file_name", "") or audio.get("file_name", "")
|
||||
)
|
||||
yield "caption", message.get("caption", "")
|
||||
# Determine message type and extract content
|
||||
elif "text" in message:
|
||||
yield "event", "text"
|
||||
yield "text", message.get("text", "")
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", ""
|
||||
elif "photo" in message:
|
||||
# Get the largest photo (last in array)
|
||||
photos = message.get("photo", [])
|
||||
photo_fid = photos[-1].get("file_id", "") if photos else ""
|
||||
yield "event", "photo"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", photo_fid
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "voice" in message:
|
||||
voice = message.get("voice", {})
|
||||
yield "event", "voice"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", voice.get("file_id", "")
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "audio" in message:
|
||||
audio = message.get("audio", {})
|
||||
yield "event", "audio"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", audio.get("file_id", "")
|
||||
yield "file_id", ""
|
||||
yield "file_name", audio.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "document" in message:
|
||||
document = message.get("document", {})
|
||||
yield "event", "document"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", document.get("file_id", "")
|
||||
yield "file_name", document.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "video" in message:
|
||||
video = message.get("video", {})
|
||||
yield "event", "video"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", video.get("file_id", "")
|
||||
yield "file_name", video.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
else:
|
||||
yield "event", "other"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", ""
|
||||
|
||||
|
||||
# Example payload for reaction trigger testing
|
||||
EXAMPLE_REACTION_PAYLOAD = {
|
||||
"update_id": 123456790,
|
||||
"message_reaction": {
|
||||
"chat": {
|
||||
"id": 12345678,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"type": "private",
|
||||
},
|
||||
"message_id": 42,
|
||||
"user": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "John",
|
||||
"username": "johndoe",
|
||||
},
|
||||
"date": 1234567890,
|
||||
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
|
||||
"old_reaction": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
|
||||
"""
|
||||
Triggers when a reaction to a message is changed.
|
||||
|
||||
Works automatically in private chats. In group chats, the bot must be
|
||||
an administrator to receive reaction updates.
|
||||
"""
|
||||
|
||||
class Input(TelegramTriggerBase.Input):
|
||||
pass
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload from Telegram"
|
||||
)
|
||||
chat_id: int = SchemaField(
|
||||
description="The chat ID where the reaction occurred"
|
||||
)
|
||||
message_id: int = SchemaField(description="The message ID that was reacted to")
|
||||
user_id: int = SchemaField(description="The user ID who changed the reaction")
|
||||
username: str = SchemaField(description="Username of the user (may be empty)")
|
||||
new_reactions: list = SchemaField(
|
||||
description="List of new reactions on the message"
|
||||
)
|
||||
old_reactions: list = SchemaField(
|
||||
description="List of previous reactions on the message"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82525328-9368-4966-8f0c-cd78e80181fd",
|
||||
description="Triggers when a reaction to a message is changed. "
|
||||
"Works in private chats automatically. "
|
||||
"In groups, the bot must be an administrator.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TelegramMessageReactionTriggerBlock.Input,
|
||||
output_schema=TelegramMessageReactionTriggerBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.TELEGRAM,
|
||||
webhook_type=TelegramWebhookType.BOT,
|
||||
resource_format="bot",
|
||||
event_filter_input="",
|
||||
event_format="message_reaction",
|
||||
),
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"payload": EXAMPLE_REACTION_PAYLOAD,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("payload", EXAMPLE_REACTION_PAYLOAD),
|
||||
("chat_id", 12345678),
|
||||
("message_id", 42),
|
||||
("user_id", 12345678),
|
||||
("username", "johndoe"),
|
||||
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
|
||||
("old_reactions", []),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
reaction = payload.get("message_reaction", {})
|
||||
|
||||
chat = reaction.get("chat", {})
|
||||
user = reaction.get("user", {})
|
||||
|
||||
yield "payload", payload
|
||||
yield "chat_id", chat.get("id", 0)
|
||||
yield "message_id", reaction.get("message_id", 0)
|
||||
yield "user_id", user.get("id", 0)
|
||||
yield "username", user.get("username", "")
|
||||
yield "new_reactions", reaction.get("new_reaction", [])
|
||||
yield "old_reactions", reaction.get("old_reaction", [])
|
||||
@@ -2,7 +2,6 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||
from backend.blocks.system.library_operations import (
|
||||
AddToLibraryFromStoreBlock,
|
||||
LibraryAgent,
|
||||
@@ -122,10 +121,7 @@ async def test_search_store_agents_block(mocker):
|
||||
)
|
||||
|
||||
input_data = block.Input(
|
||||
query="test",
|
||||
category="productivity",
|
||||
sort_by=StoreAgentsSortOptions.RATING, # type: ignore[reportArgumentType]
|
||||
limit=10,
|
||||
query="test", category="productivity", sort_by="rating", limit=10
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
|
||||
@@ -34,12 +34,10 @@ def main(output: Path, pretty: bool):
|
||||
"""Generate and output the OpenAPI JSON specification."""
|
||||
openapi_schema = get_openapi_schema()
|
||||
|
||||
json_output = json.dumps(
|
||||
openapi_schema, indent=2 if pretty else None, ensure_ascii=False
|
||||
)
|
||||
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
|
||||
|
||||
if output:
|
||||
output.write_text(json_output, encoding="utf-8")
|
||||
output.write_text(json_output)
|
||||
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
|
||||
click.echo(f"\n{json_output[:500]} ...")
|
||||
else:
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .service import stream_chat_completion_baseline
|
||||
|
||||
__all__ = ["stream_chat_completion_baseline"]
|
||||
@@ -1,424 +0,0 @@
|
||||
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
|
||||
|
||||
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
|
||||
Claude Agent SDK / Anthropic API is unavailable. Routes through any
|
||||
OpenAI-compatible provider (OpenRouter by default) and reuses the same
|
||||
shared tool registry as the SDK path.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
# Maximum number of tool-call rounds before forcing a text response.
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None
|
||||
) -> None:
|
||||
"""Generate and persist a session title in the background."""
|
||||
try:
|
||||
title = await _generate_session_title(message, user_id, session_id)
|
||||
if title and user_id:
|
||||
await update_session_title(session_id, user_id, title, only_if_empty=True)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Failed to update session title: %s", e)
|
||||
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
Uses the shared compress_context() utility which supports LLM-based
|
||||
summarization of older messages while keeping recent ones intact,
|
||||
with progressive truncation and middle-out deletion as fallbacks.
|
||||
"""
|
||||
messages_dict = []
|
||||
for msg in messages:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
return [
|
||||
ChatMessage(role=m["role"], content=m.get("content"))
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
|
||||
Designed as a fallback when the Claude Agent SDK is unavailable.
|
||||
Uses the same tool registry as the SDK path but routes through any
|
||||
OpenAI-compatible provider (e.g. OpenRouter).
|
||||
|
||||
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
|
||||
"""
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
# Append user message
|
||||
new_role = "user" if is_user_message else "assistant"
|
||||
if message and (
|
||||
len(session.messages) == 0
|
||||
or not (
|
||||
session.messages[-1].role == new_role
|
||||
and session.messages[-1].content == message
|
||||
)
|
||||
):
|
||||
session.messages.append(ChatMessage(role=new_role, content=message))
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(message),
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
|
||||
# Build OpenAI message list from session history
|
||||
openai_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
for msg in messages_for_context:
|
||||
if msg.role in ("user", "assistant") and msg.content:
|
||||
openai_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
tools = get_available_tools()
|
||||
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Propagate user/session context to Langfuse so all LLM calls within
|
||||
# this request are grouped under a single trace with proper attribution.
|
||||
_trace_ctx: Any = None
|
||||
try:
|
||||
_trace_ctx = propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
trace_name="copilot-baseline",
|
||||
tags=["baseline"],
|
||||
)
|
||||
_trace_ctx.__enter__()
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
assistant_text = ""
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
yield StreamStartStep()
|
||||
step_open = True
|
||||
|
||||
# Stream a response from the model
|
||||
create_kwargs: dict[str, Any] = dict(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
# Text content
|
||||
if delta.content:
|
||||
if not text_started:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
text_started = True
|
||||
round_text += delta.content
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.content)
|
||||
|
||||
# Tool call fragments (streamed incrementally)
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
tool_calls_by_index[idx] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
entry = tool_calls_by_index[idx]
|
||||
if tc.id:
|
||||
entry["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
entry["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
entry["arguments"] += tc.function.arguments
|
||||
|
||||
# Close text block if we had one this round
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_started = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Accumulate text for session persistence
|
||||
assistant_text += round_text
|
||||
|
||||
# No tool calls -> model is done
|
||||
if not tool_calls_by_index:
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
break
|
||||
|
||||
# Close step before tool execution
|
||||
yield StreamFinishStep()
|
||||
step_open = False
|
||||
|
||||
# Append the assistant message with tool_calls to context.
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if round_text:
|
||||
assistant_msg["content"] = round_text
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["arguments"] or "{}",
|
||||
},
|
||||
}
|
||||
for tc in tool_calls_by_index.values()
|
||||
]
|
||||
openai_messages.append(assistant_msg)
|
||||
|
||||
# Execute each tool call and stream events
|
||||
for tc in tool_calls_by_index.values():
|
||||
tool_call_id = tc["id"]
|
||||
tool_name = tc["name"]
|
||||
raw_args = tc["arguments"] or "{}"
|
||||
try:
|
||||
tool_args = orjson.loads(raw_args)
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = (
|
||||
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
)
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
)
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": parse_error,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
)
|
||||
|
||||
# Execute via shared tool registry
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
yield result
|
||||
tool_output = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else str(result.output)
|
||||
)
|
||||
except Exception as e:
|
||||
error_output = f"Tool execution error: {e}"
|
||||
logger.error(
|
||||
"[Baseline] Tool %s failed: %s",
|
||||
tool_name,
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
)
|
||||
tool_output = error_output
|
||||
|
||||
# Append tool result to context for next round
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": tool_output,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# for-loop exhausted without break -> tool-round limit hit
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
"without a final response."
|
||||
)
|
||||
logger.error("[Baseline] %s", limit_msg)
|
||||
yield StreamError(
|
||||
errorText=limit_msg,
|
||||
code="baseline_tool_round_limit",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
if text_started:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
if step_open:
|
||||
yield StreamFinishStep()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=assistant_text)
|
||||
)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
yield StreamFinish()
|
||||
@@ -1,99 +0,0 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.model import (
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_baseline_multi_turn(setup_test_user, test_user_id):
|
||||
"""Test that the baseline LLM path streams responses and maintains history.
|
||||
|
||||
Turn 1: Send a message with a unique keyword.
|
||||
Turn 2: Ask the model to recall the keyword — proving conversation history
|
||||
is correctly passed to the single-call LLM.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "QUASAR99"
|
||||
turn1_msg = (
|
||||
f"Please remember this special keyword: {keyword}. "
|
||||
"Just confirm you've noted it, keep your response brief."
|
||||
)
|
||||
turn1_text = ""
|
||||
turn1_errors: list[str] = []
|
||||
got_start = False
|
||||
got_finish = False
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn1_msg,
|
||||
user_id=test_user_id,
|
||||
):
|
||||
if isinstance(chunk, StreamStart):
|
||||
got_start = True
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
turn1_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn1_errors.append(chunk.errorText)
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
got_finish = True
|
||||
|
||||
assert got_start, "Turn 1 did not yield StreamStart"
|
||||
assert got_finish, "Turn 1 did not yield StreamFinish"
|
||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
||||
assert turn1_text, "Turn 1 produced no text"
|
||||
logger.info(f"Turn 1 response: {turn1_text[:100]}")
|
||||
|
||||
# Reload session for turn 2
|
||||
session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert session, "Session not found after turn 1"
|
||||
|
||||
# Verify messages were persisted (user + assistant)
|
||||
assert (
|
||||
len(session.messages) >= 2
|
||||
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
|
||||
|
||||
# --- Turn 2: ask model to recall the keyword ---
|
||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
||||
turn2_text = ""
|
||||
turn2_errors: list[str] = []
|
||||
|
||||
async for chunk in stream_chat_completion_baseline(
|
||||
session.session_id,
|
||||
turn2_msg,
|
||||
user_id=test_user_id,
|
||||
session=session,
|
||||
):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
turn2_text += chunk.delta
|
||||
elif isinstance(chunk, StreamError):
|
||||
turn2_errors.append(chunk.errorText)
|
||||
|
||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
||||
assert turn2_text, "Turn 2 produced no text"
|
||||
assert keyword in turn2_text, (
|
||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
||||
f"Response: {turn2_text[:200]}"
|
||||
)
|
||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
||||
349
autogpt_platform/backend/backend/copilot/completion_consumer.py
Normal file
349
autogpt_platform/backend/backend/copilot/completion_consumer.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
Database operations are handled through the chat_db() accessor, which
|
||||
routes through DatabaseManager RPC when Prisma is not directly connected.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
await process_operation_success(task, message.result)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
await process_operation_failure(task, message.error)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
329
autogpt_platform/backend/backend/copilot/completion_handler.py
Normal file
329
autogpt_platform/backend/backend/copilot/completion_handler.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Update tool message in database using the chat_db accessor.
|
||||
|
||||
Routes through DatabaseManager RPC when Prisma is not directly
|
||||
connected (e.g. in the CoPilot Executor microservice).
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_tool_message_content(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
new_content=content,
|
||||
)
|
||||
if not updated:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id="
|
||||
f"{tool_call_id} in session {session_id}"
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task
|
||||
will be marked as failed instead of completed.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database
|
||||
with the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -1,13 +1,10 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -22,18 +19,31 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default=OPENROUTER_BASE_URL,
|
||||
default="https://openrouter.ai/api/v1",
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for long-running operation deduplication lock "
|
||||
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
|
||||
"For longer operations, the stream_registry heartbeat keeps them alive.",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
@@ -49,14 +59,36 @@ class ChatConfig(BaseSettings):
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
session_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for session metadata hash keys",
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
turn_stream_prefix: str = Field(
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for turn message stream keys",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
@@ -65,15 +97,11 @@ class ChatConfig(BaseSettings):
|
||||
default="CoPilot Prompt",
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
langfuse_prompt_cache_ttl: int = Field(
|
||||
default=300,
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
@@ -87,88 +115,25 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
claude_agent_max_subtasks: int = Field(
|
||||
default=10,
|
||||
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
|
||||
description="Max number of sub-agent Tasks the SDK can spawn per session.",
|
||||
)
|
||||
claude_agent_use_resume: bool = Field(
|
||||
default=True,
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Use E2B cloud sandboxes for persistent bash/python execution. "
|
||||
"When enabled, bash_exec routes commands to E2B and SDK file tools "
|
||||
"operate directly on the sandbox via E2B's filesystem API.",
|
||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
||||
)
|
||||
e2b_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
|
||||
)
|
||||
e2b_sandbox_template: str = Field(
|
||||
default="base",
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=10800, # 3 hours — wall-clock timeout, not idle; explicit pause is primary
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
)
|
||||
e2b_sandbox_on_timeout: Literal["kill", "pause"] = Field(
|
||||
default="pause",
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
|
||||
Single source of truth for "should we use E2B right now?".
|
||||
Prefer this over combining ``use_e2b_sandbox`` and ``e2b_api_key``
|
||||
separately at call sites.
|
||||
"""
|
||||
return self.use_e2b_sandbox and bool(self.e2b_api_key)
|
||||
|
||||
@property
|
||||
def active_e2b_api_key(self) -> str | None:
|
||||
"""Return the E2B API key when E2B is enabled and configured, else None.
|
||||
|
||||
Combines the ``use_e2b_sandbox`` flag check and key presence into one.
|
||||
Use in callers::
|
||||
|
||||
if api_key := config.active_e2b_api_key:
|
||||
# E2B is active; api_key is narrowed to str
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
"""Get E2B API key from environment if not provided."""
|
||||
if not v:
|
||||
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if not v:
|
||||
if v is None:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
@@ -178,16 +143,13 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
|
||||
# The SDK CLI picks it up from the env directly. Including it
|
||||
# would pair it with the OpenRouter base_url, causing auth failures.
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if not v:
|
||||
if v is None:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("CHAT_BASE_URL")
|
||||
if not v:
|
||||
@@ -195,7 +157,15 @@ class ChatConfig(BaseSettings):
|
||||
if not v:
|
||||
v = os.getenv("OPENAI_BASE_URL")
|
||||
if not v:
|
||||
v = OPENROUTER_BASE_URL
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@@ -209,15 +179,6 @@ class ChatConfig(BaseSettings):
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Unit tests for ChatConfig."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
def test_both_enabled_and_key_present_returns_true(self):
|
||||
"""e2b_active is True when use_e2b_sandbox=True and e2b_api_key is set."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is True
|
||||
|
||||
def test_enabled_but_missing_key_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=True but e2b_api_key is absent."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=True, e2b_api_key=None)
|
||||
assert cfg.e2b_active is False
|
||||
|
||||
def test_disabled_returns_false(self):
|
||||
"""e2b_active is False when use_e2b_sandbox=False regardless of key."""
|
||||
cfg = ChatConfig(use_e2b_sandbox=False, e2b_api_key="test-key")
|
||||
assert cfg.e2b_active is False
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Shared constants for the CoPilot module."""
|
||||
|
||||
# Special message prefixes for text-based markers (parsed by frontend).
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Compaction notice messages shown to users.
|
||||
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
|
||||
COMPACTION_TOOL_NAME = "context_compaction"
|
||||
@@ -1,115 +0,0 @@
|
||||
"""Shared execution context for copilot SDK tool handlers.
|
||||
|
||||
All context variables and their accessors live here so that
|
||||
``tool_adapter``, ``file_ref``, and ``e2b_file_tools`` can import them
|
||||
without creating circular dependencies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
# validation can scope tool-results reads to the current session.
|
||||
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
|
||||
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
"_current_sandbox", default=None
|
||||
)
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
sandbox: "AsyncSandbox | None" = None,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> None:
|
||||
"""Set per-turn context variables used by file-resolution tool handlers."""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_sandbox.set(sandbox)
|
||||
_current_sdk_cwd.set(sdk_cwd or "")
|
||||
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Return the current (user_id, session) pair for the active request."""
|
||||
return _current_user_id.get(), _current_session.get()
|
||||
|
||||
|
||||
def get_current_sandbox() -> "AsyncSandbox | None":
|
||||
"""Return the E2B sandbox for the current session, or None if not active."""
|
||||
return _current_sandbox.get()
|
||||
|
||||
|
||||
def get_sdk_cwd() -> str:
|
||||
"""Return the SDK working directory for the current session (empty string if unset)."""
|
||||
return _current_sdk_cwd.get()
|
||||
|
||||
|
||||
E2B_WORKDIR = "/home/user"
|
||||
|
||||
|
||||
def resolve_sandbox_path(path: str) -> str:
|
||||
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
|
||||
|
||||
Raises :class:`ValueError` if the resolved path escapes the sandbox.
|
||||
"""
|
||||
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
|
||||
normalized = os.path.normpath(candidate)
|
||||
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
|
||||
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
|
||||
return True
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1,163 +0,0 @@
|
||||
"""Tests for context.py — execution context variables and path helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = "test-session"
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context variable getters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_execution_context_defaults():
|
||||
"""get_execution_context returns (None, session) when user_id is not set."""
|
||||
set_execution_context(None, _make_session())
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id is None
|
||||
assert session is not None
|
||||
|
||||
|
||||
def test_set_and_get_execution_context():
|
||||
"""set_execution_context stores user_id and session."""
|
||||
mock_session = _make_session()
|
||||
set_execution_context("user-abc", mock_session)
|
||||
user_id, session = get_execution_context()
|
||||
assert user_id == "user-abc"
|
||||
assert session is mock_session
|
||||
|
||||
|
||||
def test_get_current_sandbox_none_by_default():
|
||||
"""get_current_sandbox returns None when no sandbox is set."""
|
||||
set_execution_context("u1", _make_session(), sandbox=None)
|
||||
assert get_current_sandbox() is None
|
||||
|
||||
|
||||
def test_get_current_sandbox_returns_set_value():
|
||||
"""get_current_sandbox returns the sandbox set via set_execution_context."""
|
||||
mock_sandbox = MagicMock()
|
||||
set_execution_context("u1", _make_session(), sandbox=mock_sandbox)
|
||||
assert get_current_sandbox() is mock_sandbox
|
||||
|
||||
|
||||
def test_get_sdk_cwd_empty_when_not_set():
|
||||
"""get_sdk_cwd returns empty string when sdk_cwd is not set."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd=None)
|
||||
assert get_sdk_cwd() == ""
|
||||
|
||||
|
||||
def test_get_sdk_cwd_returns_set_value():
|
||||
"""get_sdk_cwd returns the value set via set_execution_context."""
|
||||
set_execution_context("u1", _make_session(), sdk_cwd="/tmp/copilot-test")
|
||||
assert get_sdk_cwd() == "/tmp/copilot-test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_allowed_local_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_allowed_local_path_empty():
|
||||
assert not is_allowed_local_path("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_inside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
path = os.path.join(cwd, "file.txt")
|
||||
assert is_allowed_local_path(path, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sdk_cwd_itself():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert is_allowed_local_path(cwd, cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_outside_sdk_cwd():
|
||||
with tempfile.TemporaryDirectory() as cwd:
|
||||
assert not is_allowed_local_path("/etc/passwd", cwd)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
"""Without sdk_cwd or project_dir, all paths are rejected."""
|
||||
_current_project_dir.set("")
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(sibling_path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_valid():
|
||||
assert (
|
||||
resolve_sandbox_path("/home/user/project/main.py")
|
||||
== "/home/user/project/main.py"
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_relative():
|
||||
assert resolve_sandbox_path("project/main.py") == "/home/user/project/main.py"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_workdir_itself():
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_normalizes_dots():
|
||||
assert resolve_sandbox_path("/home/user/a/../b") == "/home/user/b"
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_escape_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
|
||||
def test_resolve_sandbox_path_absolute_outside_raises():
|
||||
with pytest.raises(ValueError, match="/home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
@@ -16,7 +16,7 @@ from prisma.types import (
|
||||
)
|
||||
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson, sanitize_string
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
|
||||
@@ -81,35 +81,6 @@ async def update_chat_session(
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def update_chat_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
|
||||
Always filters by (session_id, user_id) so callers cannot mutate another
|
||||
user's session even when they know the session_id.
|
||||
|
||||
Args:
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
guard so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns True if a row was updated, False otherwise (session not found,
|
||||
wrong user, or — when only_if_empty — title was already set).
|
||||
"""
|
||||
where: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
|
||||
if only_if_empty:
|
||||
where["title"] = None
|
||||
result = await PrismaChatSession.prisma().update_many(
|
||||
where=where,
|
||||
data={"title": title, "updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
return result > 0
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
session_id: str,
|
||||
role: str,
|
||||
@@ -130,16 +101,15 @@ async def add_chat_message(
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
|
||||
# control characters (null bytes etc.) that may appear in tool outputs.
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = sanitize_string(content)
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = sanitize_string(refusal)
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
@@ -200,16 +170,15 @@ async def add_chat_messages_batch(
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields — sanitize to strip
|
||||
# PostgreSQL-incompatible control characters.
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = sanitize_string(msg["content"])
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = sanitize_string(msg["refusal"])
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
@@ -343,7 +312,7 @@ async def update_tool_message_content(
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={
|
||||
"content": sanitize_string(new_content),
|
||||
"content": new_content,
|
||||
},
|
||||
)
|
||||
if result == 0:
|
||||
|
||||
@@ -4,7 +4,6 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -26,7 +25,7 @@ from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_turn, init_worker
|
||||
from .processor import execute_copilot_task, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
@@ -182,13 +181,13 @@ class CoPilotExecutor(AppProcess):
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for session_id, lock in list(self._task_locks.items()):
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
@@ -268,20 +267,20 @@ class CoPilotExecutor(AppProcess):
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
session_id = request.session_id
|
||||
if not session_id:
|
||||
logger.warning("Cancel message missing 'session_id'")
|
||||
task_id = request.task_id
|
||||
if not task_id:
|
||||
logger.warning("Cancel message missing 'task_id'")
|
||||
return
|
||||
if session_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {session_id} but not active")
|
||||
if task_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {task_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[session_id]
|
||||
logger.info(f"Received cancel for {session_id}")
|
||||
_, cancel_event = self.active_tasks[task_id]
|
||||
logger.info(f"Received cancel for {task_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {session_id}")
|
||||
logger.debug(f"Cancel already set for {task_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
@@ -353,12 +352,12 @@ class CoPilotExecutor(AppProcess):
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
session_id = entry.session_id
|
||||
task_id = entry.task_id
|
||||
|
||||
# Check for local duplicate - session is already running on this executor
|
||||
if session_id in self.active_tasks:
|
||||
# Check for local duplicate - task is already running on this executor
|
||||
if task_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Session {session_id} already running locally, rejecting duplicate"
|
||||
f"Task {task_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
@@ -366,69 +365,64 @@ class CoPilotExecutor(AppProcess):
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:session:{session_id}:lock",
|
||||
key=f"copilot:task:{task_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(
|
||||
f"Session {session_id} already running on pod {current_owner}"
|
||||
)
|
||||
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {session_id} - Redis unavailable"
|
||||
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[session_id] = cluster_lock
|
||||
self._task_locks[task_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {session_id}, "
|
||||
f"executor_id={self.executor_id}"
|
||||
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_turn, entry, cancel_event, cluster_lock
|
||||
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[session_id] = (future, cancel_event)
|
||||
self.active_tasks[task_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {session_id}: {e}")
|
||||
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if session_id in self._task_locks:
|
||||
del self._task_locks[session_id]
|
||||
if task_id in self._task_locks:
|
||||
del self._task_locks[task_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {session_id}")
|
||||
error_msg = None
|
||||
logger.info(f"Run completed for {task_id}")
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
error_msg = str(exec_error) or type(exec_error).__name__
|
||||
logger.error(f"Execution for {session_id} failed: {error_msg}")
|
||||
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||
# Don't requeue failed tasks - they've been marked as failed
|
||||
# in the stream registry. Requeuing would cause infinite retries
|
||||
# for deterministic failures.
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Run completion callback cancelled for {session_id}")
|
||||
except BaseException as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.exception(f"Error in run completion callback: {error_msg}")
|
||||
logger.exception(f"Error in run completion callback: {e}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if session_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {session_id}")
|
||||
self._task_locks[session_id].release()
|
||||
del self._task_locks[session_id]
|
||||
if task_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {task_id}")
|
||||
self._task_locks[task_id].release()
|
||||
del self._task_locks[task_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
@@ -439,11 +433,11 @@ class CoPilotExecutor(AppProcess):
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
with self._active_tasks_lock:
|
||||
for session_id, (future, _) in list(self.active_tasks.items()):
|
||||
for task_id, (future, _) in list(self.active_tasks.items()):
|
||||
if future.done():
|
||||
completed_tasks.append(session_id)
|
||||
self.active_tasks.pop(session_id, None)
|
||||
logger.info(f"Cleaned up completed session {session_id}")
|
||||
completed_tasks.append(task_id)
|
||||
self.active_tasks.pop(task_id, None)
|
||||
logger.info(f"Cleaned up completed task {task_id}")
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
"""CoPilot execution processor - per-worker execution logic.
|
||||
|
||||
This module contains the processor class that handles CoPilot session execution
|
||||
This module contains the processor class that handles CoPilot task execution
|
||||
in a thread-local context, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
@@ -34,17 +32,17 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def execute_copilot_turn(
|
||||
def execute_copilot_task(
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a single CoPilot turn (user message → AI response).
|
||||
"""Execute a CoPilot task using the thread-local processor.
|
||||
|
||||
This function is the entry point called by the thread pool executor.
|
||||
|
||||
Args:
|
||||
entry: The turn payload
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for this execution
|
||||
"""
|
||||
@@ -78,16 +76,16 @@ def cleanup_worker():
|
||||
|
||||
|
||||
class CoPilotProcessor:
|
||||
"""Per-worker execution logic for CoPilot sessions.
|
||||
"""Per-worker execution logic for CoPilot tasks.
|
||||
|
||||
This class is instantiated once per worker thread and handles the execution
|
||||
of CoPilot chat generation sessions. It maintains an async event loop for
|
||||
of CoPilot chat generation tasks. It maintains an async event loop for
|
||||
running the async service code.
|
||||
|
||||
The execution flow:
|
||||
1. Session entry is picked from RabbitMQ queue
|
||||
2. Manager submits to thread pool
|
||||
3. Processor executes in its event loop
|
||||
1. CoPilot task is picked from RabbitMQ queue
|
||||
2. Manager submits task to thread pool
|
||||
3. Processor executes the task in its event loop
|
||||
4. Results are published to Redis Streams
|
||||
"""
|
||||
|
||||
@@ -110,41 +108,8 @@ class CoPilotProcessor:
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
# Skip the SDK's per-request CLI version check — the bundled CLI is
|
||||
# already version-matched to the SDK package.
|
||||
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
|
||||
|
||||
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
|
||||
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
|
||||
self._prewarm_cli()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def _prewarm_cli(self) -> None:
|
||||
"""Run the bundled CLI binary once to warm OS page caches."""
|
||||
try:
|
||||
from claude_agent_sdk._internal.transport.subprocess_cli import (
|
||||
SubprocessCLITransport,
|
||||
)
|
||||
|
||||
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
|
||||
if cli_path:
|
||||
result = subprocess.run(
|
||||
[cli_path, "-v"],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
|
||||
result.returncode, # type: ignore[reportCallIssue]
|
||||
cli_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
@@ -154,16 +119,13 @@ class CoPilotProcessor:
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
shutdown_workspace_storage(), self.execution_loop
|
||||
)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
coro.close() # Prevent "coroutine was never awaited" warning
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.warning(
|
||||
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
|
||||
)
|
||||
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
|
||||
|
||||
# Stop the event loop
|
||||
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
|
||||
@@ -177,17 +139,19 @@ class CoPilotProcessor:
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot turn.
|
||||
"""Execute a CoPilot task.
|
||||
|
||||
Runs the async logic in the worker's event loop and handles errors.
|
||||
This is the main entry point for task execution. It runs the async
|
||||
execution logic in the worker's event loop and handles errors.
|
||||
|
||||
Args:
|
||||
entry: The turn payload containing session and message info
|
||||
entry: The task payload containing session and message info
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock to prevent duplicate execution
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
task_id=entry.task_id,
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
@@ -195,30 +159,38 @@ class CoPilotProcessor:
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Run the async execution in our event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._execute_async(entry, cancel, cluster_lock, log),
|
||||
self.execution_loop,
|
||||
)
|
||||
try:
|
||||
# Run the async execution in our event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._execute_async(entry, cancel, cluster_lock, log),
|
||||
self.execution_loop,
|
||||
)
|
||||
|
||||
# Wait for completion, checking cancel periodically
|
||||
while not future.done():
|
||||
try:
|
||||
future.result(timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
if cancel.is_set():
|
||||
log.info("Cancellation requested")
|
||||
future.cancel()
|
||||
break
|
||||
# Refresh cluster lock to maintain ownership
|
||||
cluster_lock.refresh()
|
||||
# Wait for completion, checking cancel periodically
|
||||
while not future.done():
|
||||
try:
|
||||
future.result(timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
if cancel.is_set():
|
||||
log.info("Cancellation requested")
|
||||
future.cancel()
|
||||
break
|
||||
# Refresh cluster lock to maintain ownership
|
||||
cluster_lock.refresh()
|
||||
|
||||
if not future.cancelled():
|
||||
# Get result to propagate any exceptions
|
||||
future.result()
|
||||
if not future.cancelled():
|
||||
# Get result to propagate any exceptions
|
||||
future.result()
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
||||
# Note: _execute_async already marks the task as failed before re-raising,
|
||||
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
||||
raise
|
||||
|
||||
async def _execute_async(
|
||||
self,
|
||||
@@ -227,26 +199,24 @@ class CoPilotProcessor:
|
||||
cluster_lock: ClusterLock,
|
||||
log: CoPilotLogMetadata,
|
||||
):
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
"""Async execution logic for CoPilot task.
|
||||
|
||||
Calls the chat completion service (SDK or baseline) and publishes
|
||||
results to the stream registry.
|
||||
This method calls the existing stream_chat_completion service function
|
||||
and publishes results to the stream registry.
|
||||
|
||||
Args:
|
||||
entry: The turn payload
|
||||
entry: The task payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for refresh
|
||||
log: Structured logger
|
||||
log: Structured logger for this task
|
||||
"""
|
||||
last_refresh = time.monotonic()
|
||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
# Choose service based on LaunchDarkly flag
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
use_sdk = await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
@@ -254,60 +224,68 @@ class CoPilotProcessor:
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
else copilot_service.stream_chat_completion
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
# Stream chat completion and publish chunks to Redis
|
||||
async for chunk in stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
):
|
||||
# Check for cancellation
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.task_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id, status="failed"
|
||||
)
|
||||
return
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Skip StreamFinish — mark_session_completed publishes it.
|
||||
if isinstance(chunk, StreamFinish):
|
||||
continue
|
||||
# Publish chunk to stream registry
|
||||
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||
log.info("Task completed successfully")
|
||||
|
||||
# Stream loop completed
|
||||
if cancel.is_set():
|
||||
log.info("Stream cancelled by user")
|
||||
|
||||
except BaseException as e:
|
||||
# Handle all exceptions (including CancelledError) with appropriate logging
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
log.info("Turn cancelled")
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
log.error(f"Turn failed: {error_msg}")
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_task_completed(
|
||||
entry.task_id,
|
||||
status="failed",
|
||||
error_message="Task was cancelled",
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# If no exception but user cancelled, still mark as cancelled
|
||||
if not error_msg and cancel.is_set():
|
||||
error_msg = "Operation cancelled"
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=error_msg
|
||||
)
|
||||
except Exception as mark_err:
|
||||
log.error(f"Failed to mark session completed: {mark_err}")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.task_id, str(e))
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id, StreamError(errorText=error_message)
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||
|
||||
@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
@@ -135,15 +135,18 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID (also used for dedup/locking)"""
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
turn_id: str = ""
|
||||
"""Per-turn UUID for Redis stream isolation"""
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
@@ -153,50 +156,47 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
session_id: str
|
||||
"""Session ID to cancel"""
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_turn(
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID (also used for dedup/locking)
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
@@ -207,15 +207,15 @@ async def enqueue_copilot_turn(
|
||||
)
|
||||
|
||||
|
||||
async def enqueue_cancel_task(session_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot session.
|
||||
async def enqueue_cancel_task(task_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot task.
|
||||
|
||||
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
|
||||
pods receive the cancellation signal.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
event = CancelCoPilotEvent(session_id=session_id)
|
||||
event = CancelCoPilotEvent(task_id=task_id)
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key="", # FANOUT ignores routing key
|
||||
|
||||
@@ -469,16 +469,8 @@ async def upsert_chat_session(
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed).
|
||||
# Title updates (update_session_title) run *outside* this lock because
|
||||
# they only touch the title field, not messages. So a concurrent rename
|
||||
# or auto-title may have written a newer title to Redis while this
|
||||
# upsert was in progress. Always prefer the cached title to avoid
|
||||
# overwriting it with the stale in-memory copy.
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
existing_cached = await _get_session_from_cache(session.session_id)
|
||||
if existing_cached and existing_cached.title:
|
||||
session = session.model_copy(update={"title": existing_cached.title})
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
@@ -680,47 +672,27 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
# Shut down any local browser daemon for this session (best-effort).
|
||||
# Inline import required: all tool modules import ChatSession from this
|
||||
# module, so any top-level import from tools.* would create a cycle.
|
||||
try:
|
||||
from .tools.agent_browser import close_browser_session
|
||||
|
||||
await close_browser_session(session_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Browser cleanup for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
title: str,
|
||||
*,
|
||||
only_if_empty: bool = False,
|
||||
) -> bool:
|
||||
"""Update the title of a chat session, scoped to the owning user.
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
Lightweight operation that doesn't touch messages, avoiding race conditions
|
||||
with concurrent message updates.
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
user_id: Owning user — the DB query filters on this.
|
||||
title: The new title to set.
|
||||
only_if_empty: When True, uses an atomic ``UPDATE WHERE title IS NULL``
|
||||
so auto-generated titles never overwrite a user-set title.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise (not found, wrong user,
|
||||
or — when only_if_empty — title was already set).
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
updated = await chat_db().update_chat_session_title(
|
||||
session_id, user_id, title, only_if_empty=only_if_empty
|
||||
)
|
||||
if not updated:
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
@@ -732,8 +704,9 @@ async def update_session_title(
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Cache title update failed for session {session_id} (non-critical): {e}"
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Scheduler job to generate LLM-optimized block descriptions.
|
||||
|
||||
Runs periodically to rewrite block descriptions into concise, actionable
|
||||
summaries that help the copilot LLM pick the right blocks during agent
|
||||
generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from backend.blocks import get_blocks
|
||||
from backend.util.clients import get_database_manager_client, get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a technical writer for an automation platform. "
|
||||
"Rewrite the following block description to be concise (under 50 words), "
|
||||
"informative, and actionable. Focus on what the block does and when to "
|
||||
"use it. Output ONLY the rewritten description, nothing else. "
|
||||
"Do not use markdown formatting."
|
||||
)
|
||||
|
||||
# Rate-limit delay between sequential LLM calls (seconds)
|
||||
_RATE_LIMIT_DELAY = 0.5
|
||||
# Maximum tokens for optimized description generation
|
||||
_MAX_DESCRIPTION_TOKENS = 150
|
||||
# Model for generating optimized descriptions (fast, cheap)
|
||||
_MODEL = "gpt-4o-mini"
|
||||
|
||||
|
||||
async def _optimize_descriptions(blocks: list[dict[str, str]]) -> dict[str, str]:
|
||||
"""Call the shared OpenAI client to rewrite each block description."""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.error(
|
||||
"No OpenAI client configured, skipping block description optimization"
|
||||
)
|
||||
return {}
|
||||
|
||||
results: dict[str, str] = {}
|
||||
for block in blocks:
|
||||
block_id = block["id"]
|
||||
block_name = block["name"]
|
||||
description = block["description"]
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Block name: {block_name}\nDescription: {description}",
|
||||
},
|
||||
],
|
||||
max_tokens=_MAX_DESCRIPTION_TOKENS,
|
||||
)
|
||||
optimized = (response.choices[0].message.content or "").strip()
|
||||
if optimized:
|
||||
results[block_id] = optimized
|
||||
logger.debug("Optimized description for %s", block_name)
|
||||
else:
|
||||
logger.warning("Empty response for block %s", block_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to optimize description for %s", block_name, exc_info=True
|
||||
)
|
||||
|
||||
await asyncio.sleep(_RATE_LIMIT_DELAY)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def optimize_block_descriptions() -> dict[str, int]:
|
||||
"""Generate optimized descriptions for blocks that don't have one yet.
|
||||
|
||||
Uses the shared OpenAI client to rewrite block descriptions into concise
|
||||
summaries suitable for agent generation prompts.
|
||||
|
||||
Returns:
|
||||
Dict with counts: processed, success, failed, skipped.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
blocks = db_client.get_blocks_needing_optimization()
|
||||
if not blocks:
|
||||
logger.info("All blocks already have optimized descriptions")
|
||||
return {"processed": 0, "success": 0, "failed": 0, "skipped": 0}
|
||||
|
||||
logger.info("Found %d blocks needing optimized descriptions", len(blocks))
|
||||
|
||||
non_empty = [b for b in blocks if b.get("description", "").strip()]
|
||||
skipped = len(blocks) - len(non_empty)
|
||||
|
||||
new_descriptions = asyncio.run(_optimize_descriptions(non_empty))
|
||||
|
||||
stats = {
|
||||
"processed": len(non_empty),
|
||||
"success": len(new_descriptions),
|
||||
"failed": len(non_empty) - len(new_descriptions),
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Block description optimization complete: "
|
||||
"%d/%d succeeded, %d failed, %d skipped",
|
||||
stats["success"],
|
||||
stats["processed"],
|
||||
stats["failed"],
|
||||
stats["skipped"],
|
||||
)
|
||||
|
||||
if new_descriptions:
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
db_client.update_block_optimized_description(block_id, optimized)
|
||||
|
||||
# Update in-memory descriptions first so the cache rebuilds with fresh data.
|
||||
try:
|
||||
block_classes = get_blocks()
|
||||
for block_id, optimized in new_descriptions.items():
|
||||
if block_id in block_classes:
|
||||
block_classes[block_id]._optimized_description = optimized
|
||||
logger.info(
|
||||
"Updated %d in-memory block descriptions", len(new_descriptions)
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not update in-memory block descriptions", exc_info=True
|
||||
)
|
||||
|
||||
from backend.copilot.tools.agent_generator.blocks import (
|
||||
reset_block_caches, # local to avoid circular import
|
||||
)
|
||||
|
||||
reset_block_caches()
|
||||
|
||||
return stats
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Unit tests for optimize_blocks._optimize_descriptions."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.copilot.optimize_blocks import _RATE_LIMIT_DELAY, _optimize_descriptions
|
||||
|
||||
|
||||
def _make_client_response(text: str) -> MagicMock:
|
||||
"""Build a minimal mock that looks like an OpenAI ChatCompletion response."""
|
||||
choice = MagicMock()
|
||||
choice.message.content = text
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
return response
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestOptimizeDescriptions:
|
||||
"""Tests for _optimize_descriptions async function."""
|
||||
|
||||
def test_returns_empty_when_no_client(self):
|
||||
with patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=None
|
||||
):
|
||||
result = _run(
|
||||
_optimize_descriptions([{"id": "b1", "name": "B", "description": "d"}])
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_success_single_block(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("Short desc.")
|
||||
)
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {"b1": "Short desc."}
|
||||
client.chat.completions.create.assert_called_once()
|
||||
|
||||
def test_skips_block_on_exception(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(side_effect=Exception("API error"))
|
||||
blocks = [{"id": "b1", "name": "MyBlock", "description": "A block."}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.asyncio.sleep", new_callable=AsyncMock
|
||||
),
|
||||
):
|
||||
result = _run(_optimize_descriptions(blocks))
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_sleeps_between_blocks(self):
|
||||
client = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_client_response("desc")
|
||||
)
|
||||
blocks = [
|
||||
{"id": "b1", "name": "B1", "description": "d1"},
|
||||
{"id": "b2", "name": "B2", "description": "d2"},
|
||||
]
|
||||
sleep_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.optimize_blocks.get_openai_client", return_value=client
|
||||
),
|
||||
patch("backend.copilot.optimize_blocks.asyncio.sleep", sleep_mock),
|
||||
):
|
||||
_run(_optimize_descriptions(blocks))
|
||||
|
||||
assert sleep_mock.call_count == 2
|
||||
sleep_mock.assert_called_with(_RATE_LIMIT_DELAY)
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
# Import here to allow module-level mocking if needed
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
n_tools = 3
|
||||
delay_per_tool = 0.2
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
# Minimal session mock
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
input={},
|
||||
)
|
||||
await asyncio.sleep(delay_per_tool)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
output="{}",
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
original_yield = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
start = time.monotonic()
|
||||
events = []
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
elapsed = time.monotonic() - start
|
||||
finally:
|
||||
svc._yield_tool_call = original_yield
|
||||
|
||||
assert len(events) == n_tools * 2
|
||||
# Parallel: should take ~delay, not ~n*delay
|
||||
assert elapsed < delay_per_tool * (
|
||||
n_tools - 0.5
|
||||
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_works():
|
||||
"""Single tool call should work identically."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": "call_0",
|
||||
"type": "function",
|
||||
"function": {"name": "t", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = [
|
||||
e
|
||||
async for e in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
)
|
||||
]
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_propagates():
|
||||
"""Retryable errors should be raised after all tools finish."""
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = []
|
||||
with pytest.raises(KeyError):
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
# First tool's events should still be yielded
|
||||
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_lock_shared():
|
||||
"""All parallel tools should receive the same lock instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_locks = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
observed_locks.append(lock)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
async for _ in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
pass
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_locks) == 3
|
||||
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
|
||||
assert isinstance(observed_locks[0], asyncio.Lock)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_cleans_up():
|
||||
"""Generator close should cancel in-flight tasks."""
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess, lock=None):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
started.set()
|
||||
await asyncio.sleep(10) # simulate long-running
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
|
||||
await gen.__anext__() # get first event
|
||||
await started.wait()
|
||||
await gen.aclose() # close generator
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
# If we get here without hanging, cleanup worked
|
||||
@@ -1,218 +0,0 @@
|
||||
"""Centralized prompt building logic for CoPilot.
|
||||
|
||||
This module contains all prompt construction functions and constants,
|
||||
handling the distinction between:
|
||||
- SDK mode vs Baseline mode (tool documentation needs)
|
||||
- Local mode vs E2B mode (storage/filesystem differences)
|
||||
"""
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
working_dir: str,
|
||||
sandbox_type: str,
|
||||
storage_system_1_name: str,
|
||||
storage_system_1_characteristics: list[str],
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
Template function handles all formatting (bullets, indentation, markdown).
|
||||
Callers provide clean data as lists of strings.
|
||||
|
||||
Args:
|
||||
working_dir: Working directory path
|
||||
sandbox_type: Description of bash_exec sandbox
|
||||
storage_system_1_name: Name of primary storage (ephemeral or cloud)
|
||||
storage_system_1_characteristics: List of characteristic descriptions
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
persistence = "\n".join(f" - {p}" for p in storage_system_1_persistence)
|
||||
|
||||
return f"""
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
storage_system_1_name="Ephemeral working directory",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files here are **lost between turns** — do NOT rely on them persisting",
|
||||
"Use for temporary work: running scripts, processing data, etc.",
|
||||
],
|
||||
file_move_name_1_to_2="Ephemeral → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Ephemeral",
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
storage_system_1_name="Cloud sandbox",
|
||||
storage_system_1_characteristics=[
|
||||
"Shared by all file tools AND `bash_exec` — same filesystem",
|
||||
"Full Linux environment with internet access",
|
||||
],
|
||||
storage_system_1_persistence=[
|
||||
"Files **persist across turns** within the current session",
|
||||
"Lost when the session expires (12 h inactivity)",
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
)
|
||||
|
||||
|
||||
def _generate_tool_documentation() -> str:
|
||||
"""Auto-generate tool documentation from TOOL_REGISTRY.
|
||||
|
||||
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
|
||||
SDK mode doesn't need it since Claude gets tool schemas automatically.
|
||||
|
||||
This generates a complete list of available tools with their descriptions,
|
||||
ensuring the documentation stays in sync with the actual tool implementations.
|
||||
All workflow guidance is now embedded in individual tool descriptions.
|
||||
|
||||
Only documents tools that are available in the current environment
|
||||
(checked via tool.is_available property).
|
||||
"""
|
||||
docs = "\n## AVAILABLE TOOLS\n\n"
|
||||
|
||||
# Sort tools alphabetically for consistent output
|
||||
# Filter by is_available to match get_available_tools() behavior
|
||||
for name in sorted(TOOL_REGISTRY.keys()):
|
||||
tool = TOOL_REGISTRY[name]
|
||||
if not tool.is_available:
|
||||
continue
|
||||
schema = tool.as_openai_tool()
|
||||
desc = schema["function"].get("description", "No description available")
|
||||
# Format as bullet list with tool name in code style
|
||||
docs += f"- **`{name}`**: {desc}\n"
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
|
||||
"""Get the supplement for SDK mode (Claude Agent SDK).
|
||||
|
||||
SDK mode does NOT include tool documentation because Claude automatically
|
||||
receives tool schemas from the SDK. Only includes technical notes about
|
||||
storage systems and execution environment.
|
||||
|
||||
Args:
|
||||
use_e2b: Whether E2B cloud sandbox is being used
|
||||
cwd: Current working directory (only used in local_storage mode)
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
if use_e2b:
|
||||
return _get_cloud_sandbox_supplement()
|
||||
return _get_local_storage_supplement(cwd)
|
||||
|
||||
|
||||
def get_baseline_supplement() -> str:
|
||||
"""Get the supplement for baseline mode (direct OpenAI API).
|
||||
|
||||
Baseline mode INCLUDES auto-generated tool documentation because the
|
||||
direct API doesn't automatically provide tool schemas to Claude.
|
||||
Also includes shared technical notes (but NOT SDK-specific environment details).
|
||||
|
||||
Returns:
|
||||
The supplement string to append to the system prompt
|
||||
"""
|
||||
tool_docs = _generate_tool_documentation()
|
||||
return tool_docs + _SHARED_TOOL_NOTES
|
||||
@@ -5,17 +5,12 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
@@ -52,8 +47,7 @@ class StreamBaseResponse(BaseModel):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
json_str = self.model_dump_json(exclude_none=True)
|
||||
return f"data: {json_str}\n\n"
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
@@ -64,13 +58,15 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
sessionId: str | None = Field(
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Session ID for SSE reconnection.",
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
|
||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||
import json
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
@@ -151,9 +147,6 @@ class StreamToolInputAvailable(StreamBaseResponse):
|
||||
)
|
||||
|
||||
|
||||
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
|
||||
|
||||
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
@@ -168,12 +161,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Truncate oversized outputs after construction."""
|
||||
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
## Agent Generation Guide
|
||||
|
||||
You can create, edit, and customize agents directly. You ARE the brain —
|
||||
generate the agent JSON yourself using block schemas, then validate and save.
|
||||
|
||||
### Workflow for Creating/Editing Agents
|
||||
|
||||
1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
2. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
3. **Generate JSON**: Build the agent JSON using block schemas:
|
||||
- Use block IDs from step 1 as `block_id` in nodes
|
||||
- Wire outputs to inputs using links
|
||||
- Set design-time config in `input_default`
|
||||
- Use `AgentInputBlock` for values the user provides at runtime
|
||||
4. **Write to workspace**: Save the JSON to a workspace file so the user
|
||||
can review it: `write_workspace_file(filename="agent.json", content=...)`
|
||||
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
|
||||
for errors
|
||||
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
|
||||
or fix manually based on the error descriptions. Iterate until valid.
|
||||
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
|
||||
the final `agent_json`
|
||||
|
||||
### Agent JSON Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "<UUID v4>", // auto-generated if omitted
|
||||
"version": 1,
|
||||
"is_active": true,
|
||||
"name": "Agent Name",
|
||||
"description": "What the agent does",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"block_id": "<block UUID from find_block>",
|
||||
"input_default": {
|
||||
"field_name": "design-time value"
|
||||
},
|
||||
"metadata": {
|
||||
"position": {"x": 0, "y": 0},
|
||||
"customized_name": "Optional display name"
|
||||
}
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "<UUID v4>",
|
||||
"source_id": "<source node UUID>",
|
||||
"source_name": "output_field_name",
|
||||
"sink_id": "<sink node UUID>",
|
||||
"sink_name": "input_field_name",
|
||||
"is_static": false
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### REQUIRED: AgentInputBlock and AgentOutputBlock
|
||||
|
||||
Every agent MUST include at least one AgentInputBlock and one AgentOutputBlock.
|
||||
These define the agent's interface — what it accepts and what it produces.
|
||||
|
||||
**AgentInputBlock** (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`):
|
||||
- Defines a user-facing input field on the agent
|
||||
- Required `input_default` fields: `name` (str), `value` (default: null)
|
||||
- Optional: `title`, `description`, `placeholder_values` (for dropdowns)
|
||||
- Output: `result` — the user-provided value at runtime
|
||||
- Create one AgentInputBlock per distinct input the agent needs
|
||||
|
||||
**AgentOutputBlock** (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`):
|
||||
- Defines a user-facing output displayed after the agent runs
|
||||
- Required `input_default` fields: `name` (str)
|
||||
- The `value` input should be linked from another block's output
|
||||
- Optional: `title`, `description`, `format` (Jinja2 template)
|
||||
- Create one AgentOutputBlock per distinct result to show the user
|
||||
|
||||
Without these blocks, the agent has no interface and the user cannot provide
|
||||
inputs or see outputs. NEVER skip them.
|
||||
|
||||
### Key Rules
|
||||
|
||||
- **Name & description**: Include `name` and `description` in the agent JSON
|
||||
when creating a new agent, or when editing and the agent's purpose changed.
|
||||
Without these the agent gets a generic default name.
|
||||
- **Design-time vs runtime**: `input_default` = values known at build time.
|
||||
For user-provided values, create an `AgentInputBlock` node and link its
|
||||
output to the consuming block's input.
|
||||
- **Credentials**: Do NOT require credentials upfront. Users configure
|
||||
credentials later in the platform UI after the agent is saved.
|
||||
- **Node spacing**: Position nodes with at least 800 X-units between them.
|
||||
- **Nested properties**: Use `parentField_#_childField` notation in link
|
||||
sink_name/source_name to access nested object fields.
|
||||
- **is_static links**: Set `is_static: true` when the link carries a
|
||||
design-time constant (matches a field in inputSchema with a default).
|
||||
- **ConditionBlock**: Needs a `StoreValueBlock` wired to its `value2` input.
|
||||
- **Prompt templates**: Use `{{variable}}` (double curly braces) for
|
||||
literal braces in prompt strings — single `{` and `}` are for
|
||||
template variables.
|
||||
- **AgentExecutorBlock**: When composing sub-agents, set `graph_id` and
|
||||
`graph_version` in input_default, and wire inputs/outputs to match
|
||||
the sub-agent's schema.
|
||||
|
||||
### Using Sub-Agents (AgentExecutorBlock)
|
||||
|
||||
To compose agents using other agents as sub-agents:
|
||||
1. Call `find_library_agent` to find the sub-agent — the response includes
|
||||
`graph_id`, `graph_version`, `input_schema`, and `output_schema`
|
||||
2. Create an `AgentExecutorBlock` node (ID: `e189baac-8c20-45a1-94a7-55177ea42565`)
|
||||
3. Set `input_default`:
|
||||
- `graph_id`: from the library agent's `graph_id`
|
||||
- `graph_version`: from the library agent's `graph_version`
|
||||
- `input_schema`: from the library agent's `input_schema` (JSON Schema)
|
||||
- `output_schema`: from the library agent's `output_schema` (JSON Schema)
|
||||
- `user_id`: leave as `""` (filled at runtime)
|
||||
- `inputs`: `{}` (populated by links at runtime)
|
||||
4. Wire inputs: link to sink names matching the sub-agent's `input_schema`
|
||||
property names (e.g., if input_schema has a `"url"` property, use
|
||||
`"url"` as the sink_name)
|
||||
5. Wire outputs: link from source names matching the sub-agent's
|
||||
`output_schema` property names
|
||||
6. Pass `library_agent_ids` to `create_agent`/`customize_agent` with
|
||||
the library agent IDs used, so the fixer can validate schemas
|
||||
|
||||
### Using MCP Tools (MCPToolBlock)
|
||||
|
||||
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
1. The user must specify which MCP server URL and tool name they want
|
||||
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||
3. Set `input_default`:
|
||||
- `server_url`: the MCP server URL (e.g. `"https://mcp.example.com/sse"`)
|
||||
- `selected_tool`: the tool name on that server
|
||||
- `tool_input_schema`: JSON Schema for the tool's inputs
|
||||
- `tool_arguments`: `{}` (populated by links or hardcoded values)
|
||||
4. The block requires MCP credentials — the user configures these in the
|
||||
platform UI after the agent is saved
|
||||
5. Wire inputs using the tool argument field name directly as the sink_name
|
||||
(e.g., `query`, NOT `tool_arguments_#_query`). The execution engine
|
||||
automatically collects top-level fields matching tool_input_schema into
|
||||
tool_arguments.
|
||||
6. Output: `result` (the tool's return value) and `error` (error message)
|
||||
|
||||
### Example: Simple AI Text Processor
|
||||
|
||||
A minimal agent with input, processing, and output:
|
||||
- Node 1: `AgentInputBlock` (ID: `c0a8e994-ebf1-4a9c-a4d8-89d09c86741b`,
|
||||
input_default: {"name": "user_text", "title": "Text to process"},
|
||||
output: "result")
|
||||
- Node 2: `AITextGeneratorBlock` (input: "prompt" linked from Node 1's "result")
|
||||
- Node 3: `AgentOutputBlock` (ID: `363ae599-353e-4804-937e-b2ee3cef3da4`,
|
||||
input_default: {"name": "summary", "title": "Summary"},
|
||||
input: "value" linked from Node 2's output)
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Compaction tracking for SDK-based chat sessions.
|
||||
|
||||
Encapsulates the state machine and event emission for context compaction,
|
||||
both pre-query (history compressed before SDK query) and SDK-internal
|
||||
(PreCompact hook fires mid-stream).
|
||||
|
||||
All compaction-related helpers live here: event builders, message filtering,
|
||||
persistence, and the ``CompactionTracker`` state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event builders (private — use CompactionTracker or compaction_events)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
|
||||
"""Build the opening events for a compaction tool call."""
|
||||
return [
|
||||
StreamStartStep(),
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME),
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
|
||||
"""Build the closing events for a compaction tool call."""
|
||||
return [
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=COMPACTION_TOOL_NAME,
|
||||
output=message,
|
||||
),
|
||||
StreamFinishStep(),
|
||||
]
|
||||
|
||||
|
||||
def _new_tool_call_id() -> str:
|
||||
return f"compaction-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""Create, persist, and return a self-contained compaction tool call.
|
||||
|
||||
Convenience for callers that don't use ``CompactionTracker`` (e.g. the
|
||||
legacy non-SDK streaming path in ``service.py``).
|
||||
"""
|
||||
tc_id = _new_tool_call_id()
|
||||
evts = compaction_events(COMPACTION_DONE_MSG, tool_call_id=tc_id)
|
||||
_persist(session, tc_id, COMPACTION_DONE_MSG)
|
||||
return evts
|
||||
|
||||
|
||||
def compaction_events(
|
||||
message: str, tool_call_id: str | None = None
|
||||
) -> list[StreamBaseResponse]:
|
||||
"""Emit a self-contained compaction tool call (already completed).
|
||||
|
||||
When *tool_call_id* is provided it is reused (e.g. for persistence that
|
||||
must match an already-streamed start event). Otherwise a new ID is
|
||||
generated.
|
||||
"""
|
||||
tc_id = tool_call_id or _new_tool_call_id()
|
||||
return _start_events(tc_id) + _end_events(tc_id, message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_compaction_messages(
|
||||
messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Remove synthetic compaction tool-call messages (UI-only artifacts).
|
||||
|
||||
Strips assistant messages whose only tool calls are compaction calls,
|
||||
and their corresponding tool-result messages.
|
||||
"""
|
||||
compaction_ids: set[str] = set()
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
continue
|
||||
filtered.append(msg)
|
||||
return filtered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
|
||||
"""Append compaction tool-call + result to session messages.
|
||||
|
||||
Compaction events are synthetic so they bypass the normal adapter
|
||||
accumulation. This explicitly records them so they survive a page refresh.
|
||||
"""
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": tool_call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": COMPACTION_TOOL_NAME,
|
||||
"arguments": "{}",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker — state machine for streaming sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CompactionTracker:
|
||||
"""Tracks compaction state and yields UI events.
|
||||
|
||||
Two compaction paths:
|
||||
|
||||
1. **Pre-query** — history compressed before the SDK query starts.
|
||||
Call :meth:`emit_pre_query` to yield a self-contained tool call.
|
||||
|
||||
2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
|
||||
Call :meth:`emit_start_if_ready` on heartbeat ticks and
|
||||
:meth:`emit_end_if_ready` when a message arrives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._compact_start = asyncio.Event()
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@property
|
||||
def on_compact(self) -> Callable[[], None]:
|
||||
"""Callback for the PreCompact hook."""
|
||||
return self._compact_start.set
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""Emit + persist a self-contained compaction tool call."""
|
||||
self._done = True
|
||||
return emit_compaction(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SDK-internal compaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
if self._compact_start.is_set() and not self._start_emitted and not self._done:
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = True
|
||||
self._tool_call_id = _new_tool_call_id()
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""If compaction is in progress, emit end events and persist."""
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return []
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return []
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
|
||||
persist_id = self._tool_call_id
|
||||
else:
|
||||
# PreCompact fired but start never emitted — self-contained
|
||||
persist_id = _new_tool_call_id()
|
||||
done_events = compaction_events(
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return done_events
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
|
||||
CompactionTracker state machine."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import (
|
||||
StreamFinishStep,
|
||||
StreamStartStep,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.compaction import (
|
||||
CompactionTracker,
|
||||
compaction_events,
|
||||
emit_compaction,
|
||||
filter_compaction_messages,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
|
||||
return ChatSession.new(user_id="test-user")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compaction_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionEvents:
|
||||
def test_returns_start_and_end_events(self):
|
||||
evts = compaction_events("done")
|
||||
assert len(evts) == 5
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert isinstance(evts[1], StreamToolInputStart)
|
||||
assert isinstance(evts[2], StreamToolInputAvailable)
|
||||
assert isinstance(evts[3], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[4], StreamFinishStep)
|
||||
|
||||
def test_uses_provided_tool_call_id(self):
|
||||
evts = compaction_events("msg", tool_call_id="my-id")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolCallId == "my-id"
|
||||
|
||||
def test_generates_id_when_not_provided(self):
|
||||
evts = compaction_events("msg")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolCallId.startswith("compaction-")
|
||||
|
||||
def test_tool_name_is_context_compaction(self):
|
||||
evts = compaction_events("msg")
|
||||
tool_start = evts[1]
|
||||
assert isinstance(tool_start, StreamToolInputStart)
|
||||
assert tool_start.toolName == COMPACTION_TOOL_NAME
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# emit_compaction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitCompaction:
|
||||
def test_persists_to_session(self):
|
||||
session = _make_session()
|
||||
assert len(session.messages) == 0
|
||||
evts = emit_compaction(session)
|
||||
assert len(evts) == 5
|
||||
# Should have appended 2 messages (assistant tool call + tool result)
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[0].role == "assistant"
|
||||
assert session.messages[0].tool_calls is not None
|
||||
assert (
|
||||
session.messages[0].tool_calls[0]["function"]["name"]
|
||||
== COMPACTION_TOOL_NAME
|
||||
)
|
||||
assert session.messages[1].role == "tool"
|
||||
assert session.messages[1].content == COMPACTION_DONE_MSG
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# filter_compaction_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFilterCompactionMessages:
|
||||
def test_removes_compaction_tool_calls(self):
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "comp-1",
|
||||
"type": "function",
|
||||
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
|
||||
),
|
||||
ChatMessage(role="assistant", content="world"),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0].content == "hello"
|
||||
assert filtered[1].content == "world"
|
||||
|
||||
def test_keeps_non_compaction_tool_calls(self):
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "real-1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", content="result", tool_call_id="real-1"),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 2
|
||||
|
||||
def test_keeps_assistant_with_content_and_compaction_call(self):
|
||||
"""If assistant message has both content and a compaction tool call,
|
||||
the message is kept (has real content)."""
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="I have content",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "comp-1",
|
||||
"type": "function",
|
||||
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == 1
|
||||
|
||||
def test_empty_list(self):
|
||||
assert filter_compaction_messages([]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompactionTracker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionTracker:
|
||||
def test_on_compact_sets_event(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
assert tracker._compact_start.is_set()
|
||||
|
||||
def test_emit_start_if_ready_no_event(self):
|
||||
tracker = CompactionTracker()
|
||||
assert tracker.emit_start_if_ready() == []
|
||||
|
||||
def test_emit_start_if_ready_with_event(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert isinstance(evts[1], StreamToolInputStart)
|
||||
assert isinstance(evts[2], StreamToolInputAvailable)
|
||||
|
||||
def test_emit_start_only_once(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
evts1 = tracker.emit_start_if_ready()
|
||||
assert len(evts1) == 3
|
||||
# Second call should return empty
|
||||
evts2 = tracker.emit_start_if_ready()
|
||||
assert evts2 == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_after_start(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 2
|
||||
assert isinstance(evts[0], StreamToolOutputAvailable)
|
||||
assert isinstance(evts[1], StreamFinishStep)
|
||||
# Should persist
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_without_start_self_contained(self):
|
||||
"""If PreCompact fired but start was never emitted, emit_end
|
||||
produces a self-contained compaction event."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
# Don't call emit_start_if_ready
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert len(evts) == 5 # Full self-contained event
|
||||
assert isinstance(evts[0], StreamStartStep)
|
||||
assert len(session.messages) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_done(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
# Second call should be no-op
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_no_op_when_nothing_happened(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
evts = await tracker.emit_end_if_ready(session)
|
||||
assert evts == []
|
||||
|
||||
def test_emit_pre_query(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
evts = tracker.emit_pre_query(session)
|
||||
assert len(evts) == 5
|
||||
assert len(session.messages) == 2
|
||||
assert tracker._done is True
|
||||
|
||||
def test_reset_for_query(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_blocks_sdk_compaction(self):
|
||||
"""After pre-query compaction, SDK compaction events are suppressed."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == [] # _done blocks it
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
"""After reset_for_query, compaction can fire again."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3 # Start events emitted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_id_consistency(self):
|
||||
"""Start and end events use the same tool_call_id."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact()
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
end_evts = await tracker.emit_end_if_ready(session)
|
||||
start_evt = start_evts[1]
|
||||
end_evt = end_evts[0]
|
||||
assert isinstance(start_evt, StreamToolInputStart)
|
||||
assert isinstance(end_evt, StreamToolOutputAvailable)
|
||||
assert start_evt.toolCallId == end_evt.toolCallId
|
||||
# Persisted ID should also match
|
||||
tool_calls = session.messages[0].tool_calls
|
||||
assert tool_calls is not None
|
||||
assert tool_calls[0]["id"] == start_evt.toolCallId
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
Returns a simple streaming response with text deltas to test:
|
||||
- Streaming infrastructure works
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -1,352 +0,0 @@
|
||||
"""MCP file-tool handlers that route to the E2B cloud sandbox.
|
||||
|
||||
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
|
||||
Glob/Grep so that all file operations share the same ``/home/user``
|
||||
filesystem as ``bash_exec``.
|
||||
|
||||
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
|
||||
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
|
||||
def _is_allowed_local(path: str) -> bool:
|
||||
return is_allowed_local_path(path, get_sdk_cwd())
|
||||
|
||||
|
||||
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
|
||||
if error:
|
||||
text = json.dumps({"error": text, "type": "error"})
|
||||
return {"content": [{"type": "text", "text": text}], "isError": error}
|
||||
|
||||
|
||||
def _get_sandbox_and_path(
|
||||
file_path: str,
|
||||
) -> tuple[Any, str] | dict[str, Any]:
|
||||
"""Common preamble: get sandbox + resolve path, or return MCP error."""
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
try:
|
||||
remote = resolve_sandbox_path(file_path)
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
return sandbox, remote
|
||||
|
||||
|
||||
# Tool handlers
|
||||
|
||||
|
||||
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
offset: int = max(0, int(args.get("offset", 0)))
|
||||
limit: int = max(1, int(args.get("limit", 2000)))
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
|
||||
if _is_allowed_local(file_path):
|
||||
return _read_local(file_path, offset, limit)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
lines = content.splitlines(keepends=True)
|
||||
selected = list(itertools.islice(lines, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
|
||||
|
||||
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Write content to a sandbox file, creating parent directories as needed."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
content: str = args.get("content", "")
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Successfully wrote to {remote}")
|
||||
|
||||
|
||||
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Replace a substring in a sandbox file, with optional replace-all support."""
|
||||
file_path: str = args.get("file_path", "")
|
||||
old_string: str = args.get("old_string", "")
|
||||
new_string: str = args.get("new_string", "")
|
||||
replace_all: bool = args.get("replace_all", False)
|
||||
|
||||
if not file_path:
|
||||
return _mcp("file_path is required", error=True)
|
||||
if not old_string:
|
||||
return _mcp("old_string is required", error=True)
|
||||
|
||||
result = _get_sandbox_and_path(file_path)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to read {remote}: {exc}", error=True)
|
||||
|
||||
count = content.count(old_string)
|
||||
if count == 0:
|
||||
return _mcp(f"old_string not found in {file_path}", error=True)
|
||||
if count > 1 and not replace_all:
|
||||
return _mcp(
|
||||
f"old_string appears {count} times in {file_path}. "
|
||||
"Use replace_all=true or provide a more unique string.",
|
||||
error=True,
|
||||
)
|
||||
|
||||
updated = (
|
||||
content.replace(old_string, new_string)
|
||||
if replace_all
|
||||
else content.replace(old_string, new_string, 1)
|
||||
)
|
||||
try:
|
||||
await sandbox.files.write(remote, updated)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
|
||||
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
|
||||
|
||||
|
||||
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Find files matching a name pattern inside the sandbox using ``find``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Glob failed: {exc}", error=True)
|
||||
|
||||
files = [line for line in (result.stdout or "").strip().splitlines() if line]
|
||||
return _mcp(json.dumps(files, indent=2))
|
||||
|
||||
|
||||
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Search file contents by regex inside the sandbox using ``grep -rn``."""
|
||||
pattern: str = args.get("pattern", "")
|
||||
path: str = args.get("path", "")
|
||||
include: str = args.get("include", "")
|
||||
|
||||
if not pattern:
|
||||
return _mcp("pattern is required", error=True)
|
||||
|
||||
sandbox = _get_sandbox()
|
||||
if sandbox is None:
|
||||
return _mcp("No E2B sandbox available", error=True)
|
||||
|
||||
try:
|
||||
search_dir = resolve_sandbox_path(path) if path else E2B_WORKDIR
|
||||
except ValueError as exc:
|
||||
return _mcp(str(exc), error=True)
|
||||
|
||||
parts = ["grep", "-rn", "--color=never"]
|
||||
if include:
|
||||
parts.extend(["--include", include])
|
||||
parts.extend([pattern, search_dir])
|
||||
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Grep failed: {exc}", error=True)
|
||||
|
||||
output = (result.stdout or "").strip()
|
||||
return _mcp(output if output else "No matches found.")
|
||||
|
||||
|
||||
# Local read (for SDK-internal paths)
|
||||
|
||||
|
||||
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
|
||||
"""Read from the host filesystem (defence-in-depth path check)."""
|
||||
if not _is_allowed_local(file_path):
|
||||
return _mcp(f"Path not allowed: {file_path}", error=True)
|
||||
expanded = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(expanded, encoding="utf-8", errors="replace") as fh:
|
||||
selected = list(itertools.islice(fh, offset, offset + limit))
|
||||
numbered = "".join(
|
||||
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
|
||||
)
|
||||
return _mcp(numbered)
|
||||
except FileNotFoundError:
|
||||
return _mcp(f"File not found: {file_path}", error=True)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Error reading {file_path}: {exc}", error=True)
|
||||
|
||||
|
||||
# Tool descriptors (name, description, schema, handler)
|
||||
|
||||
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
|
||||
(
|
||||
"read_file",
|
||||
"Read a file from the cloud sandbox (/home/user). "
|
||||
"Use offset and limit for large files.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line to start reading from (0-indexed). Default: 0.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000.",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
_handle_read_file,
|
||||
),
|
||||
(
|
||||
"write_file",
|
||||
"Write or create a file in the cloud sandbox (/home/user). "
|
||||
"Parent directories are created automatically. "
|
||||
"To copy a workspace file into the sandbox, use "
|
||||
"read_workspace_file with save_to_path instead.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"content": {"type": "string", "description": "Content to write."},
|
||||
},
|
||||
"required": ["file_path", "content"],
|
||||
},
|
||||
_handle_write_file,
|
||||
),
|
||||
(
|
||||
"edit_file",
|
||||
"Targeted text replacement in a sandbox file. "
|
||||
"old_string must appear in the file and is replaced with new_string.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path (relative to /home/user, or absolute).",
|
||||
},
|
||||
"old_string": {"type": "string", "description": "Text to find."},
|
||||
"new_string": {"type": "string", "description": "Replacement text."},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default: false).",
|
||||
},
|
||||
},
|
||||
"required": ["file_path", "old_string", "new_string"],
|
||||
},
|
||||
_handle_edit_file,
|
||||
),
|
||||
(
|
||||
"glob",
|
||||
"Search for files by name pattern in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern (e.g. *.py).",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to search. Default: /home/user.",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_glob,
|
||||
),
|
||||
(
|
||||
"grep",
|
||||
"Search file contents by regex in the cloud sandbox.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern."},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory. Default: /home/user.",
|
||||
},
|
||||
"include": {
|
||||
"type": "string",
|
||||
"description": "Glob to filter files (e.g. *.py).",
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
_handle_grep,
|
||||
),
|
||||
]
|
||||
|
||||
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]
|
||||
@@ -1,154 +0,0 @@
|
||||
"""Tests for E2B file-tool path validation and local read safety.
|
||||
|
||||
Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local — host filesystem reads with allowlist enforcement
|
||||
#
|
||||
# In E2B mode, _read_local only allows tool-results paths (via
|
||||
# is_allowed_local_path without sdk_cwd). Regular files live on the
|
||||
# sandbox, not the host.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def test_read_tool_results_file(self):
|
||||
"""Reading a tool-results file should succeed."""
|
||||
encoded = "-tmp-copilot-e2b-test-read"
|
||||
filepath = self._make_tool_results_file(
|
||||
encoded, "result.txt", "line 1\nline 2\nline 3\n"
|
||||
)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=2000)
|
||||
assert result["isError"] is False
|
||||
assert "line 1" in result["content"][0]["text"]
|
||||
assert "line 2" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_disallowed_path_blocked(self):
|
||||
"""Reading /etc/passwd should be blocked by the allowlist."""
|
||||
result = _read_local("/etc/passwd", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
def test_read_arbitrary_host_path_blocked(self):
|
||||
"""Arbitrary host paths are blocked even if they exist."""
|
||||
result = _read_local("/proc/self/environ", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
def test_read_with_offset_and_limit(self):
|
||||
"""Offset and limit should control which lines are returned."""
|
||||
encoded = "-tmp-copilot-e2b-test-offset"
|
||||
content = "".join(f"line {i}\n" for i in range(10))
|
||||
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local(filepath, offset=3, limit=2)
|
||||
assert result["isError"] is False
|
||||
text = result["content"][0]["text"]
|
||||
assert "line 3" in text
|
||||
assert "line 4" in text
|
||||
assert "line 2" not in text
|
||||
assert "line 5" not in text
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_read_without_project_dir_blocks_all(self):
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
@@ -1,281 +0,0 @@
|
||||
"""File reference protocol for tool call inputs.
|
||||
|
||||
Allows the LLM to pass a file reference instead of embedding large content
|
||||
inline. The processor expands ``@@agptfile:<uri>[<start>-<end>]`` tokens in tool
|
||||
arguments before the tool is executed.
|
||||
|
||||
Protocol
|
||||
--------
|
||||
|
||||
@@agptfile:<uri>[<start>-<end>]
|
||||
|
||||
``<uri>`` (required)
|
||||
- ``workspace://<file_id>`` — workspace file by ID
|
||||
- ``workspace://<file_id>#<mime>`` — same, MIME hint is ignored for reads
|
||||
- ``workspace:///<path>`` — workspace file by virtual path
|
||||
- ``/absolute/local/path`` — ephemeral or sdk_cwd file (validated by
|
||||
:func:`~backend.copilot.sdk.tool_adapter.is_allowed_local_path`)
|
||||
- Any absolute path that resolves inside the E2B sandbox
|
||||
(``/home/user/...``) when a sandbox is active
|
||||
|
||||
``[<start>-<end>]`` (optional)
|
||||
Line range, 1-indexed inclusive. Examples: ``[1-100]``, ``[50-200]``.
|
||||
Omit to read the entire file.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@agptfile:workspace://abc123
|
||||
@@agptfile:workspace://abc123[10-50]
|
||||
@@agptfile:workspace:///reports/q1.md
|
||||
@@agptfile:/tmp/copilot-<session>/output.py[1-80]
|
||||
@@agptfile:/home/user/script.sh
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
"""Raised when a ``@@agptfile:`` reference in tool call args fails to resolve.
|
||||
|
||||
Separating this from inline substitution lets callers (e.g. the MCP tool
|
||||
wrapper) block tool execution and surface a helpful error to the model
|
||||
rather than passing an ``[file-ref error: …]`` string as actual input.
|
||||
"""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE_REF_PREFIX = "@@agptfile:"
|
||||
|
||||
# Matches: @@agptfile:<uri>[start-end]?
|
||||
# Group 1 – URI; must start with '/' (absolute path) or 'workspace://'
|
||||
# Group 2 – start line (optional)
|
||||
# Group 3 – end line (optional)
|
||||
_FILE_REF_RE = re.compile(
|
||||
re.escape(FILE_REF_PREFIX) + r"((?:workspace://|/)[^\[\s]*)(?:\[(\d+)-(\d+)\])?"
|
||||
)
|
||||
|
||||
# Maximum characters returned for a single file reference expansion.
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileRef:
|
||||
uri: str
|
||||
start_line: int | None # 1-indexed, inclusive
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
A "bare token" means the entire string matches the ``@@agptfile:...`` pattern
|
||||
(after stripping whitespace). Use :func:`expand_file_refs_in_string` to
|
||||
expand references embedded in larger strings.
|
||||
"""
|
||||
m = _FILE_REF_RE.fullmatch(text.strip())
|
||||
if not m:
|
||||
return None
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if start is not None and start < 1:
|
||||
return None
|
||||
if end is not None and end < 1:
|
||||
return None
|
||||
if start is not None and end is not None and end < start:
|
||||
return None
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> bytes:
|
||||
"""Resolve *uri* to raw bytes using workspace, local, or E2B path logic.
|
||||
|
||||
Raises :class:`ValueError` if the URI cannot be resolved.
|
||||
"""
|
||||
# Strip MIME fragment (e.g. workspace://id#mime) before dispatching.
|
||||
plain = uri.split("#")[0] if uri.startswith("workspace://") else uri
|
||||
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
with open(resolved, "rb") as fh:
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
try:
|
||||
remote = resolve_sandbox_path(plain)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
)
|
||||
|
||||
|
||||
async def resolve_file_ref(
|
||||
ref: FileRef,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
"""Expand all ``@@agptfile:...`` tokens in *text*, returning the substituted string.
|
||||
|
||||
Non-reference text is passed through unchanged.
|
||||
|
||||
If *raise_on_error* is ``False`` (default), expansion errors are surfaced
|
||||
inline as ``[file-ref error: <message>]`` — useful for display/log contexts
|
||||
where partial expansion is acceptable.
|
||||
|
||||
If *raise_on_error* is ``True``, any resolution failure raises
|
||||
:class:`FileRefExpansionError` immediately so the caller can block the
|
||||
operation and surface a clean error to the model.
|
||||
"""
|
||||
if FILE_REF_PREFIX not in text:
|
||||
return text
|
||||
|
||||
result: list[str] = []
|
||||
last_end = 0
|
||||
total_chars = 0
|
||||
for m in _FILE_REF_RE.finditer(text):
|
||||
result.append(text[last_end : m.start()])
|
||||
start = int(m.group(2)) if m.group(2) else None
|
||||
end = int(m.group(3)) if m.group(3) else None
|
||||
if (start is not None and start < 1) or (end is not None and end < 1):
|
||||
msg = f"line numbers must be >= 1: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
if start is not None and end is not None and end < start:
|
||||
msg = f"end line must be >= start line: {m.group(0)}"
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(msg)
|
||||
result.append(f"[file-ref error: {msg}]")
|
||||
last_end = m.end()
|
||||
continue
|
||||
ref = FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
try:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
content = content[:remaining] + "\n... [total budget exhausted]"
|
||||
total_chars += len(content)
|
||||
result.append(content)
|
||||
except ValueError as exc:
|
||||
logger.warning("file-ref expansion failed for %r: %s", m.group(0), exc)
|
||||
if raise_on_error:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
result.append(f"[file-ref error: {exc}]")
|
||||
last_end = m.end()
|
||||
|
||||
result.append(text[last_end:])
|
||||
return "".join(result)
|
||||
|
||||
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
response that lets the model correct the reference before retrying.
|
||||
"""
|
||||
if not args:
|
||||
return args
|
||||
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
@@ -1,328 +0,0 @@
|
||||
"""Integration tests for @@agptfile: reference expansion in tool calls.
|
||||
|
||||
These tests verify the end-to-end behaviour of the file reference protocol:
|
||||
- Parsing @@agptfile: tokens from tool arguments
|
||||
- Resolving local-filesystem paths (sdk_cwd / ephemeral)
|
||||
- Expanding references inside the tool-call pipeline (_execute_tool_sync)
|
||||
- The extended Read tool handler (workspace:// pass-through via session context)
|
||||
|
||||
No real LLM or database is required; workspace reads are stubbed where needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
FileRef,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
read_file_bytes,
|
||||
resolve_file_ref,
|
||||
)
|
||||
from backend.copilot.sdk.tool_adapter import _read_file_handler
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "integ-sess") -> MagicMock:
|
||||
s = MagicMock()
|
||||
s.session_id = session_id
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local-file resolution (sdk_cwd)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path():
|
||||
"""resolve_file_ref reads a real local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
# Write a test file inside sdk_cwd
|
||||
test_file = os.path.join(sdk_cwd, "hello.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=None, end_line=None)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_local_path_with_line_range():
|
||||
"""resolve_file_ref respects line ranges for local files."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "multi.txt")
|
||||
lines = [f"line{i}\n" for i in range(1, 11)] # line1 … line10
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
session = _make_session()
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
ref = FileRef(uri=test_file, start_line=3, end_line=5)
|
||||
content = await resolve_file_ref(ref, user_id="u1", session=session)
|
||||
|
||||
assert content == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_file_ref_rejects_path_outside_sdk_cwd():
|
||||
"""resolve_file_ref raises ValueError for paths outside sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_sandbox_var.get.return_value = None
|
||||
|
||||
ref = FileRef(uri="/etc/passwd", start_line=None, end_line=None)
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await resolve_file_ref(ref, user_id="u1", session=_make_session())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string — integration with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_with_real_file():
|
||||
"""expand_file_refs_in_string replaces @@agptfile: token with actual content."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "data.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("hello world\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"Content: @@agptfile:{test_file}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result == "Content: hello world\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_string_missing_file_is_surfaced_inline():
|
||||
"""Missing file ref yields [file-ref error: …] inline rather than raising."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
missing = os.path.join(sdk_cwd, "does_not_exist.txt")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_string(
|
||||
f"@@agptfile:{missing}",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "[file-ref error:" in result
|
||||
assert "not found" in result.lower() or "not allowed" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — dict traversal with real files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
"""Nested @@agptfile: references in args are fully expanded."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
file_a = os.path.join(sdk_cwd, "a.txt")
|
||||
file_b = os.path.join(sdk_cwd, "b.txt")
|
||||
with open(file_a, "w") as f:
|
||||
f.write("AAA")
|
||||
with open(file_b, "w") as f:
|
||||
f.write("BBB")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"outer": {
|
||||
"content_a": f"@@agptfile:{file_a}",
|
||||
"content_b": f"start @@agptfile:{file_b} end",
|
||||
},
|
||||
"count": 42,
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["outer"]["content_a"] == "AAA"
|
||||
assert result["outer"]["content_b"] == "start BBB end"
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_local_file():
|
||||
"""_read_file_handler reads a local file when it's within sdk_cwd."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
test_file = os.path.join(sdk_cwd, "read_test.txt")
|
||||
lines = [f"L{i}\n" for i in range(1, 6)]
|
||||
with open(test_file, "w") as f:
|
||||
f.writelines(lines)
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj_var, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
mock_proj_var.get.return_value = ""
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": test_file, "offset": 0, "limit": 5}
|
||||
)
|
||||
|
||||
assert not result["isError"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "L1" in text
|
||||
assert "L5" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri():
|
||||
"""_read_file_handler handles workspace:// URIs via the workspace manager."""
|
||||
mock_session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id.return_value = b"workspace file content\nline two\n"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
{"file_path": "workspace://file-id-abc", "offset": 0, "limit": 10}
|
||||
)
|
||||
|
||||
assert not result["isError"], result["content"][0]["text"]
|
||||
text = result["content"][0]["text"]
|
||||
assert "workspace file content" in text
|
||||
assert "line two" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_workspace_uri_no_session():
|
||||
"""_read_file_handler returns error when workspace:// is used without session."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=(None, None),
|
||||
):
|
||||
result = await _read_file_handler({"file_path": "workspace://some-id"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "session" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_handler_access_denied():
|
||||
"""_read_file_handler rejects paths outside allowed locations."""
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox, patch(
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", _make_session()),
|
||||
):
|
||||
mock_cwd.get.return_value = "/tmp/safe-dir"
|
||||
mock_sandbox.get.return_value = None
|
||||
|
||||
result = await _read_file_handler({"file_path": "/etc/passwd"})
|
||||
|
||||
assert result["isError"]
|
||||
assert "not allowed" in result["content"][0]["text"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file_bytes — workspace:///path (virtual path) and E2B sandbox branch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_workspace_virtual_path():
|
||||
"""workspace:///path resolves via manager.read_file (is_path=True path)."""
|
||||
session = _make_session()
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
assert result == b"virtual path content"
|
||||
mock_manager.read_file.assert_awaited_once_with("/reports/q1.md")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_sandbox_branch():
|
||||
"""read_file_bytes reads from the E2B sandbox when a sandbox is active."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
mock_sandbox.files.read.return_value = bytearray(b"sandbox content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
result = await read_file_bytes("/home/user/script.sh", None, session)
|
||||
|
||||
assert result == b"sandbox content"
|
||||
mock_sandbox.files.read.assert_awaited_once_with(
|
||||
"/home/user/script.sh", format="bytes"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_bytes_e2b_path_escapes_sandbox_raises():
|
||||
"""read_file_bytes raises ValueError for paths that escape the sandbox root."""
|
||||
session = _make_session()
|
||||
mock_sandbox = AsyncMock()
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd, patch(
|
||||
"backend.copilot.context._current_sandbox"
|
||||
) as mock_sandbox_var, patch(
|
||||
"backend.copilot.context._current_project_dir"
|
||||
) as mock_proj:
|
||||
mock_cwd.get.return_value = ""
|
||||
mock_sandbox_var.get.return_value = mock_sandbox
|
||||
mock_proj.get.return_value = ""
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
await read_file_bytes("/etc/passwd", None, session)
|
||||
@@ -1,382 +0,0 @@
|
||||
"""Tests for the @@agptfile: reference protocol (file_ref.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.sdk.file_ref import (
|
||||
_MAX_EXPAND_CHARS,
|
||||
FileRef,
|
||||
FileRefExpansionError,
|
||||
_apply_line_range,
|
||||
expand_file_refs_in_args,
|
||||
expand_file_refs_in_string,
|
||||
parse_file_ref,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=None, end_line=None)
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_id_with_mime():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123#text/plain")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123#text/plain"
|
||||
assert ref.start_line is None
|
||||
|
||||
|
||||
def test_parse_file_ref_workspace_path():
|
||||
ref = parse_file_ref("@@agptfile:workspace:///reports/q1.md")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace:///reports/q1.md"
|
||||
|
||||
|
||||
def test_parse_file_ref_with_line_range():
|
||||
ref = parse_file_ref("@@agptfile:workspace://abc123[10-50]")
|
||||
assert ref == FileRef(uri="workspace://abc123", start_line=10, end_line=50)
|
||||
|
||||
|
||||
def test_parse_file_ref_local_path():
|
||||
ref = parse_file_ref("@@agptfile:/tmp/copilot-session/output.py[1-100]")
|
||||
assert ref is not None
|
||||
assert ref.uri == "/tmp/copilot-session/output.py"
|
||||
assert ref.start_line == 1
|
||||
assert ref.end_line == 100
|
||||
|
||||
|
||||
def test_parse_file_ref_no_match():
|
||||
assert parse_file_ref("just a normal string") is None
|
||||
assert parse_file_ref("workspace://abc123") is None # missing @@agptfile: prefix
|
||||
assert (
|
||||
parse_file_ref("@@agptfile:workspace://abc123 extra") is None
|
||||
) # not full match
|
||||
|
||||
|
||||
def test_parse_file_ref_strips_whitespace():
|
||||
ref = parse_file_ref(" @@agptfile:workspace://abc123 ")
|
||||
assert ref is not None
|
||||
assert ref.uri == "workspace://abc123"
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[0-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_end_less_than_start():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[10-5]") is None
|
||||
|
||||
|
||||
def test_parse_file_ref_invalid_range_zero_end():
|
||||
assert parse_file_ref("@@agptfile:workspace://abc123[1-0]") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _apply_line_range
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
TEXT = "line1\nline2\nline3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_no_range():
|
||||
assert _apply_line_range(TEXT, None, None) == TEXT
|
||||
|
||||
|
||||
def test_apply_line_range_start_only():
|
||||
result = _apply_line_range(TEXT, 3, None)
|
||||
assert result == "line3\nline4\nline5\n"
|
||||
|
||||
|
||||
def test_apply_line_range_full():
|
||||
result = _apply_line_range(TEXT, 2, 4)
|
||||
assert result == "line2\nline3\nline4\n"
|
||||
|
||||
|
||||
def test_apply_line_range_single_line():
|
||||
result = _apply_line_range(TEXT, 2, 2)
|
||||
assert result == "line2\n"
|
||||
|
||||
|
||||
def test_apply_line_range_beyond_eof():
|
||||
result = _apply_line_range(TEXT, 4, 999)
|
||||
assert result == "line4\nline5\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(session_id: str = "sess-1") -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.session_id = session_id
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_always(ref: FileRef, _user_id: str | None, _session: object) -> str:
|
||||
"""Stub resolver that returns the URI and range as a descriptive string."""
|
||||
if ref.start_line is not None:
|
||||
return f"content:{ref.uri}[{ref.start_line}-{ref.end_line}]"
|
||||
return f"content:{ref.uri}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_no_refs():
|
||||
result = await expand_file_refs_in_string(
|
||||
"no references here", user_id="u1", session=_make_session()
|
||||
)
|
||||
assert result == "no references here"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_single_ref():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_with_range():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[10-50]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://abc123[10-50]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_embedded_in_text():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"Here is the file: @@agptfile:workspace://abc123 — done",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "Here is the file: content:workspace://abc123 — done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_multiple_refs():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://file1 and @@agptfile:workspace://file2[1-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == "content:workspace://file1 and content:workspace://file2[1-5]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_zero_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] for zero-start ranges."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://abc123[0-5]",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "line numbers must be >= 1" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_invalid_range_end_less_than_start_surfaces_inline():
|
||||
"""expand_file_refs_in_string surfaces [file-ref error: ...] when end < start."""
|
||||
result = await expand_file_refs_in_string(
|
||||
"prefix @@agptfile:workspace://abc123[10-5] suffix",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "end line must be >= start line" in result
|
||||
assert "prefix" in result
|
||||
assert "suffix" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_ref_error_surfaces_inline():
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("file not found")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://bad",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "[file-ref error:" in result
|
||||
assert "file not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_flat():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"content": "@@agptfile:workspace://abc123", "other": 42},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["content"] == "content:workspace://abc123"
|
||||
assert result["other"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_nested_dict():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{"outer": {"inner": "@@agptfile:workspace://nested"}},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["outer"]["inner"] == "content:workspace://nested"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_list():
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_always),
|
||||
):
|
||||
result = await expand_file_refs_in_args(
|
||||
{
|
||||
"items": [
|
||||
"@@agptfile:workspace://a",
|
||||
"plain",
|
||||
"@@agptfile:workspace://b[1-3]",
|
||||
]
|
||||
},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result["items"] == [
|
||||
"content:workspace://a",
|
||||
"plain",
|
||||
"content:workspace://b[1-3]",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_empty():
|
||||
result = await expand_file_refs_in_args({}, user_id="u1", session=_make_session())
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_no_refs():
|
||||
result = await expand_file_refs_in_args(
|
||||
{"key": "no refs here", "num": 1},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert result == {"key": "no refs here", "num": 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_args_raises_on_file_ref_error():
|
||||
"""expand_file_refs_in_args raises FileRefExpansionError instead of passing
|
||||
the inline error string to the tool, blocking tool execution."""
|
||||
|
||||
async def _raise(*args, **kwargs): # noqa: ARG001
|
||||
raise ValueError("path does not exist")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_raise),
|
||||
):
|
||||
with pytest.raises(FileRefExpansionError) as exc_info:
|
||||
await expand_file_refs_in_args(
|
||||
{"prompt": "@@agptfile:/home/user/missing.txt"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
assert "path does not exist" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file truncation and aggregate budget
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_per_file_truncation():
|
||||
"""Content exceeding _MAX_EXPAND_CHARS is truncated with a marker."""
|
||||
oversized = "x" * (_MAX_EXPAND_CHARS + 100)
|
||||
|
||||
async def _resolve_oversized(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return oversized
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_oversized),
|
||||
):
|
||||
result = await expand_file_refs_in_string(
|
||||
"@@agptfile:workspace://big-file",
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert len(result) <= _MAX_EXPAND_CHARS + len("\n... [truncated]") + 10
|
||||
assert "[truncated]" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expand_aggregate_budget_exhausted():
|
||||
"""When the aggregate budget is exhausted, later refs get the budget message."""
|
||||
# Each file returns just under 300K; after ~4 files the 1M budget is used.
|
||||
big_chunk = "y" * 300_000
|
||||
|
||||
async def _resolve_big(ref: FileRef, _uid: str | None, _s: object) -> str:
|
||||
return big_chunk
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.resolve_file_ref",
|
||||
new=AsyncMock(side_effect=_resolve_big),
|
||||
):
|
||||
# 5 refs @ 300K each = 1.5M → last ref(s) should hit the aggregate limit
|
||||
refs = " ".join(f"@@agptfile:workspace://f{i}" for i in range(5))
|
||||
result = await expand_file_refs_in_string(
|
||||
refs,
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert "budget exhausted" in result
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user