mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
85 Commits
fix/copilo
...
swiftyos/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
726542472a | ||
|
|
757ec1f064 | ||
|
|
9442c648a4 | ||
|
|
1c51dd18aa | ||
|
|
6f4f80871d | ||
|
|
e8cca6cd9a | ||
|
|
bf6308e87c | ||
|
|
4e59143d16 | ||
|
|
d5efb6915b | ||
|
|
b9aac42056 | ||
|
|
95651d33da | ||
|
|
b30418d833 | ||
|
|
ed729ddbe2 | ||
|
|
8c7030af0b | ||
|
|
195b14286a | ||
|
|
29ca034e40 | ||
|
|
1d9dd782a8 | ||
|
|
a1cb3d2a91 | ||
|
|
1b91327034 | ||
|
|
c7cdb40c5b | ||
|
|
77fb4419d0 | ||
|
|
9f002ce8f6 | ||
|
|
74691076c6 | ||
|
|
b15ad0df9b | ||
|
|
2136defea8 | ||
|
|
6e61cb103c | ||
|
|
0e72e1f5e7 | ||
|
|
163b0b3c9d | ||
|
|
ef42b17e3b | ||
|
|
a18ffd0b21 | ||
|
|
e40c8c70ce | ||
|
|
9cdcd6793f | ||
|
|
fc64f83331 | ||
|
|
7718c49f05 | ||
|
|
0a1591fce2 | ||
|
|
681bb7c2b4 | ||
|
|
0818cd6683 | ||
|
|
7a39bdfaf8 | ||
|
|
0b151f64e8 | ||
|
|
be2a48aedb | ||
|
|
aeca4dbb79 | ||
|
|
7b85eeaae2 | ||
|
|
4db3be2d61 | ||
|
|
f57a1995d0 | ||
|
|
3928c35928 | ||
|
|
dc77e7b6e6 | ||
|
|
ba75cc28b5 | ||
|
|
15bcdae4e8 | ||
|
|
e9ba7e51db | ||
|
|
d23248f065 | ||
|
|
905373a712 | ||
|
|
ee9d39bc0f | ||
|
|
05aaf7a85e | ||
|
|
9d4dcbd9e0 | ||
|
|
074be7aea6 | ||
|
|
39d28b24fc | ||
|
|
bf79a7748a | ||
|
|
649d4ab7f5 | ||
|
|
223df9d3da | ||
|
|
187ab04745 | ||
|
|
e2d3c8a217 | ||
|
|
647c8ed8d4 | ||
|
|
27d94e395c | ||
|
|
b8f5c208d0 | ||
|
|
ca216dfd7f | ||
|
|
f9f358c526 | ||
|
|
52b3aebf71 | ||
|
|
965b7d3e04 | ||
|
|
c2368f15ff | ||
|
|
9ac3f64d56 | ||
|
|
5035b69c79 | ||
|
|
86af8fc856 | ||
|
|
dfa517300b | ||
|
|
43b25b5e2f | ||
|
|
ab0b537cc7 | ||
|
|
9a8c6ad609 | ||
|
|
e8c50b96d1 | ||
|
|
30e854569a | ||
|
|
301d7cbada | ||
|
|
d95aef7665 | ||
|
|
cb166dd6fb | ||
|
|
3d31f62bf1 | ||
|
|
b8b6c9de23 | ||
|
|
4f6055f494 | ||
|
|
695a185fa1 |
@@ -5,42 +5,13 @@
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
!autogpt_platform/autogpt_libs/poetry.lock
|
||||
!autogpt_platform/autogpt_libs/README.md
|
||||
!autogpt_platform/autogpt_libs/
|
||||
|
||||
# Platform - Backend
|
||||
!autogpt_platform/backend/backend/
|
||||
!autogpt_platform/backend/test/e2e_test_data.py
|
||||
!autogpt_platform/backend/migrations/
|
||||
!autogpt_platform/backend/schema.prisma
|
||||
!autogpt_platform/backend/pyproject.toml
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
!autogpt_platform/market/scripts.py
|
||||
!autogpt_platform/market/schema.prisma
|
||||
!autogpt_platform/market/pyproject.toml
|
||||
!autogpt_platform/market/poetry.lock
|
||||
!autogpt_platform/market/README.md
|
||||
!autogpt_platform/backend/
|
||||
|
||||
# Platform - Frontend
|
||||
!autogpt_platform/frontend/src/
|
||||
!autogpt_platform/frontend/public/
|
||||
!autogpt_platform/frontend/scripts/
|
||||
!autogpt_platform/frontend/package.json
|
||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
||||
!autogpt_platform/frontend/tsconfig.json
|
||||
!autogpt_platform/frontend/README.md
|
||||
## config
|
||||
!autogpt_platform/frontend/*.config.*
|
||||
!autogpt_platform/frontend/.env.*
|
||||
!autogpt_platform/frontend/.env
|
||||
!autogpt_platform/frontend/
|
||||
|
||||
# Classic - AutoGPT
|
||||
!classic/original_autogpt/autogpt/
|
||||
@@ -64,6 +35,38 @@
|
||||
# Classic - Frontend
|
||||
!classic/frontend/build/web/
|
||||
|
||||
# Explicitly re-ignore some folders
|
||||
.*
|
||||
**/__pycache__
|
||||
# Explicitly re-ignore unwanted files from whitelisted directories
|
||||
# Note: These patterns MUST come after the whitelist rules to take effect
|
||||
|
||||
# Hidden files and directories (but keep frontend .env files needed for build)
|
||||
**/.*
|
||||
!autogpt_platform/frontend/.env
|
||||
!autogpt_platform/frontend/.env.default
|
||||
!autogpt_platform/frontend/.env.production
|
||||
|
||||
# Python artifacts
|
||||
**/__pycache__/
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/.venv/
|
||||
**/.ruff_cache/
|
||||
**/.pytest_cache/
|
||||
**/.coverage
|
||||
**/htmlcov/
|
||||
|
||||
# Node artifacts
|
||||
**/node_modules/
|
||||
**/.next/
|
||||
**/storybook-static/
|
||||
**/playwright-report/
|
||||
**/test-results/
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
!autogpt_platform/frontend/src/**/build/
|
||||
**/target/
|
||||
|
||||
# Logs and temp files
|
||||
**/*.log
|
||||
**/*.tmp
|
||||
|
||||
1229
.github/scripts/detect_overlaps.py
vendored
Normal file
1229
.github/scripts/detect_overlaps.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
42
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
42
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
@@ -40,6 +40,48 @@ jobs:
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||
|
||||
# Backend Python/Poetry setup (so Claude can run linting/tests)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v8
|
||||
|
||||
22
.github/workflows/claude-dependabot.yml
vendored
22
.github/workflows/claude-dependabot.yml
vendored
@@ -77,27 +77,15 @@ jobs:
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
22
.github/workflows/claude.yml
vendored
22
.github/workflows/claude.yml
vendored
@@ -93,27 +93,15 @@ jobs:
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
@@ -93,6 +93,6 @@ jobs:
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
34
.github/workflows/docs-claude-review.yml
vendored
34
.github/workflows/docs-claude-review.yml
vendored
@@ -7,6 +7,10 @@ on:
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
|
||||
concurrency:
|
||||
group: claude-docs-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Only run for PRs from members/collaborators
|
||||
@@ -91,5 +95,35 @@ jobs:
|
||||
3. Read corresponding documentation files to verify accuracy
|
||||
4. Provide your feedback as a PR comment
|
||||
|
||||
## IMPORTANT: Comment Marker
|
||||
Start your PR comment with exactly this HTML comment marker on its own line:
|
||||
<!-- CLAUDE_DOCS_REVIEW -->
|
||||
|
||||
This marker is used to identify and replace your comment on subsequent runs.
|
||||
|
||||
Be constructive and specific. If everything looks good, say so!
|
||||
If there are issues, explain what's wrong and suggest how to fix it.
|
||||
|
||||
- name: Delete old Claude review comments
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Get all comment IDs with our marker, sorted by creation date (oldest first)
|
||||
COMMENT_IDS=$(gh api \
|
||||
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
|
||||
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
|
||||
|
||||
# Count comments
|
||||
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
|
||||
|
||||
if [ "$COMMENT_COUNT" -gt 1 ]; then
|
||||
# Delete all but the last (newest) comment
|
||||
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
|
||||
if [ -n "$COMMENT_ID" ]; then
|
||||
echo "Deleting old review comment: $COMMENT_ID"
|
||||
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "No old review comments to clean up"
|
||||
fi
|
||||
|
||||
9
.github/workflows/platform-backend-ci.yml
vendored
9
.github/workflows/platform-backend-ci.yml
vendored
@@ -41,13 +41,18 @@ jobs:
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
image: rabbitmq:3.12-management
|
||||
image: rabbitmq:4.1.4
|
||||
ports:
|
||||
- 5672:5672
|
||||
- 15672:15672
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
options: >-
|
||||
--health-cmd "rabbitmq-diagnostics -q ping"
|
||||
--health-interval 30s
|
||||
--health-timeout 10s
|
||||
--health-retries 5
|
||||
--health-start-period 10s
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
|
||||
247
.github/workflows/platform-frontend-ci.yml
vendored
247
.github/workflows/platform-frontend-ci.yml
vendored
@@ -6,10 +6,16 @@ on:
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
- "autogpt_platform/backend/Dockerfile"
|
||||
- "autogpt_platform/docker-compose.yml"
|
||||
- "autogpt_platform/docker-compose.platform.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
- "autogpt_platform/backend/Dockerfile"
|
||||
- "autogpt_platform/docker-compose.yml"
|
||||
- "autogpt_platform/docker-compose.platform.yml"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
@@ -26,7 +32,6 @@ jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
components-changed: ${{ steps.filter.outputs.components }}
|
||||
|
||||
steps:
|
||||
@@ -41,28 +46,17 @@ jobs:
|
||||
components:
|
||||
- 'autogpt_platform/frontend/src/components/**'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies to populate cache
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
lint:
|
||||
@@ -73,22 +67,15 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
@@ -111,22 +98,15 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
@@ -141,10 +121,8 @@ jobs:
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -152,19 +130,11 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
@@ -172,77 +142,125 @@ jobs:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Cache Docker layers
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v3
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-frontend-test-
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Run docker compose
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Move cache
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
fi
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
- name: Wait for services to be ready
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Create E2E test data
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
# First try to run the script from inside the container
|
||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||
# Copy the script into the container and run it
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||
echo "❌ Failed to copy script to container"
|
||||
exit 1
|
||||
}
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Install Browser 'chromium'
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
@@ -269,7 +287,7 @@ jobs:
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -281,22 +299,15 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
39
.github/workflows/pr-overlap-check.yml
vendored
Normal file
39
.github/workflows/pr-overlap-check.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
name: PR Overlap Detection
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
branches:
|
||||
- dev
|
||||
- master
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-overlaps:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Need full history for merge testing
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Configure git
|
||||
run: |
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config user.name "github-actions[bot]"
|
||||
|
||||
- name: Run overlap detection
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# Always succeed - this check informs contributors, it shouldn't block merging
|
||||
continue-on-error: true
|
||||
run: |
|
||||
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
|
||||
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Add cache configuration to a resolved docker-compose file for all services
|
||||
that have a build key, and ensure image names match what docker compose expects.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
DEFAULT_BRANCH = "dev"
|
||||
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add cache config to a resolved compose file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
required=True,
|
||||
help="Source compose file to read (should be output of `docker compose config`)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-from",
|
||||
default="type=gha",
|
||||
help="Cache source configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-to",
|
||||
default="type=gha,mode=max",
|
||||
help="Cache destination configuration",
|
||||
)
|
||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||
parser.add_argument(
|
||||
f"--{component}-hash",
|
||||
default="",
|
||||
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--git-ref",
|
||||
default="",
|
||||
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
|
||||
git_ref_scope = ""
|
||||
if args.git_ref:
|
||||
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
|
||||
|
||||
with open(args.source, "r") as f:
|
||||
compose = yaml.safe_load(f)
|
||||
|
||||
# Get project name from compose file or default
|
||||
project_name = compose.get("name", "autogpt_platform")
|
||||
|
||||
def get_image_name(dockerfile: str, target: str) -> str:
|
||||
"""Generate image name based on Dockerfile folder and build target."""
|
||||
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
|
||||
if len(dockerfile_parts) >= 2:
|
||||
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
|
||||
else:
|
||||
folder_name = "app"
|
||||
return f"{project_name}-{folder_name}:{target}"
|
||||
|
||||
def get_build_key(dockerfile: str, target: str) -> str:
|
||||
"""Generate a unique key for a Dockerfile+target combination."""
|
||||
return f"{dockerfile}:{target}"
|
||||
|
||||
def get_component(dockerfile: str) -> str | None:
|
||||
"""Get component name (frontend/backend) from dockerfile path."""
|
||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||
if component in dockerfile:
|
||||
return component
|
||||
return None
|
||||
|
||||
# First pass: collect all services with build configs and identify duplicates
|
||||
# Track which (dockerfile, target) combinations we've seen
|
||||
build_key_to_first_service: dict[str, str] = {}
|
||||
services_to_build: list[str] = []
|
||||
services_to_dedupe: list[str] = []
|
||||
|
||||
for service_name, service_config in compose.get("services", {}).items():
|
||||
if "build" not in service_config:
|
||||
continue
|
||||
|
||||
build_config = service_config["build"]
|
||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||
target = build_config.get("target", "default")
|
||||
build_key = get_build_key(dockerfile, target)
|
||||
|
||||
if build_key not in build_key_to_first_service:
|
||||
# First service with this build config - it will do the actual build
|
||||
build_key_to_first_service[build_key] = service_name
|
||||
services_to_build.append(service_name)
|
||||
else:
|
||||
# Duplicate - will just use the image from the first service
|
||||
services_to_dedupe.append(service_name)
|
||||
|
||||
# Second pass: configure builds and deduplicate
|
||||
modified_services = []
|
||||
for service_name, service_config in compose.get("services", {}).items():
|
||||
if "build" not in service_config:
|
||||
continue
|
||||
|
||||
build_config = service_config["build"]
|
||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||
target = build_config.get("target", "latest")
|
||||
image_name = get_image_name(dockerfile, target)
|
||||
|
||||
# Set image name for all services (needed for both builders and deduped)
|
||||
service_config["image"] = image_name
|
||||
|
||||
if service_name in services_to_dedupe:
|
||||
# Remove build config - this service will use the pre-built image
|
||||
del service_config["build"]
|
||||
continue
|
||||
|
||||
# This service will do the actual build - add cache config
|
||||
cache_from_list = []
|
||||
cache_to_list = []
|
||||
|
||||
component = get_component(dockerfile)
|
||||
if not component:
|
||||
# Skip services that don't clearly match frontend/backend
|
||||
continue
|
||||
|
||||
# Get the hash for this component
|
||||
component_hash = getattr(args, f"{component}_hash")
|
||||
|
||||
# Scope format: platform-{component}-{target}-{hash|ref}
|
||||
# Example: platform-backend-server-abc123
|
||||
|
||||
if "type=gha" in args.cache_from:
|
||||
# 1. Primary: exact hash match (most specific)
|
||||
if component_hash:
|
||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
|
||||
|
||||
# 2. Fallback: branch-based cache
|
||||
if git_ref_scope:
|
||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
|
||||
|
||||
# 3. Fallback: dev branch cache (for PRs/feature branches)
|
||||
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
|
||||
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
|
||||
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
|
||||
|
||||
if "type=gha" in args.cache_to:
|
||||
# Write to both hash-based and branch-based scopes
|
||||
if component_hash:
|
||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
|
||||
|
||||
if git_ref_scope:
|
||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
|
||||
|
||||
# Ensure we have at least one cache source/target
|
||||
if not cache_from_list:
|
||||
cache_from_list.append(args.cache_from)
|
||||
if not cache_to_list:
|
||||
cache_to_list.append(args.cache_to)
|
||||
|
||||
build_config["cache_from"] = cache_from_list
|
||||
build_config["cache_to"] = cache_to_list
|
||||
modified_services.append(service_name)
|
||||
|
||||
# Write back to the same file
|
||||
with open(args.source, "w") as f:
|
||||
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
|
||||
for svc in modified_services:
|
||||
svc_config = compose["services"][svc]
|
||||
build_cfg = svc_config.get("build", {})
|
||||
cache_from_list = build_cfg.get("cache_from", ["none"])
|
||||
cache_to_list = build_cfg.get("cache_to", ["none"])
|
||||
print(f" - {svc}")
|
||||
print(f" image: {svc_config.get('image', 'N/A')}")
|
||||
print(f" cache_from: {cache_from_list}")
|
||||
print(f" cache_to: {cache_to_list}")
|
||||
if services_to_dedupe:
|
||||
print(
|
||||
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
|
||||
)
|
||||
for svc in services_to_dedupe:
|
||||
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -180,4 +180,6 @@ autogpt_platform/backend/settings.py
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
default_install_hook_types:
|
||||
- pre-commit
|
||||
- pre-push
|
||||
- post-checkout
|
||||
|
||||
default_stages: [pre-commit]
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
@@ -17,6 +24,7 @@ repos:
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
files: ^autogpt_platform/
|
||||
exclude: pnpm-lock\.yaml$
|
||||
stages: [pre-push]
|
||||
|
||||
- repo: local
|
||||
@@ -26,49 +34,106 @@ repos:
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Backend
|
||||
alias: poetry-install-platform-backend
|
||||
entry: poetry -C autogpt_platform/backend install
|
||||
# include autogpt_libs source (since it's a path dependency)
|
||||
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
|
||||
types: [file]
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/backend install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Libs
|
||||
alias: poetry-install-platform-libs
|
||||
entry: poetry -C autogpt_platform/autogpt_libs install
|
||||
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
|
||||
types: [file]
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
|
||||
poetry -C autogpt_platform/autogpt_libs install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: pnpm-install
|
||||
name: Check & Install dependencies - AutoGPT Platform - Frontend
|
||||
alias: pnpm-install-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
|
||||
pnpm --prefix autogpt_platform/frontend install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt install
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/original_autogpt install
|
||||
'
|
||||
# include forge source (since it's a path dependency)
|
||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||
types: [file]
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: poetry -C classic/forge install
|
||||
files: ^classic/forge/poetry\.lock$
|
||||
types: [file]
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/forge install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: poetry -C classic/benchmark install
|
||||
files: ^classic/benchmark/poetry\.lock$
|
||||
types: [file]
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||
poetry -C classic/benchmark install
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, Prisma client must be up-to-date.
|
||||
@@ -76,12 +141,54 @@ repos:
|
||||
- id: prisma-generate
|
||||
name: Prisma Generate - AutoGPT Platform - Backend
|
||||
alias: prisma-generate-platform-backend
|
||||
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
|
||||
entry: >
|
||||
bash -c '
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||
else
|
||||
git diff --cached --name-only
|
||||
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run prisma generate
|
||||
&& poetry run gen-prisma-stub
|
||||
'
|
||||
# include everything that triggers poetry install + the prisma schema
|
||||
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
|
||||
types: [file]
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- id: export-api-schema
|
||||
name: Export API schema - AutoGPT Platform - Backend -> Frontend
|
||||
alias: export-api-schema-platform
|
||||
entry: >
|
||||
bash -c '
|
||||
cd autogpt_platform/backend
|
||||
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
&& cd ../frontend
|
||||
&& pnpm prettier --write ./src/app/api/openapi.json
|
||||
'
|
||||
files: ^autogpt_platform/backend/
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: generate-api-client
|
||||
name: Generate API client - AutoGPT Platform - Frontend
|
||||
alias: generate-api-client-platform-frontend
|
||||
entry: >
|
||||
bash -c '
|
||||
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
|
||||
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
|
||||
else
|
||||
git diff --quiet HEAD -- "$SCHEMA" && exit 0
|
||||
fi;
|
||||
cd autogpt_platform/frontend && pnpm generate:api
|
||||
'
|
||||
always_run: true
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-commit, post-checkout]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.7.2
|
||||
|
||||
@@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing:
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Branching Strategy
|
||||
|
||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||
- **`master`** is the production branch. Only used for production releases.
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
|
||||
169
autogpt_platform/autogpt_libs/poetry.lock
generated
169
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "46.0.4"
|
||||
version = "46.0.5"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
|
||||
{file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
|
||||
{file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
|
||||
{file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
|
||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
|
||||
{file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
|
||||
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
|
||||
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
@@ -570,24 +570,25 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.128.7"
|
||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"},
|
||||
{file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"},
|
||||
{file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
|
||||
{file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
annotated-doc = ">=0.0.2"
|
||||
pydantic = ">=2.7.0"
|
||||
starlette = ">=0.40.0,<0.51.0"
|
||||
starlette = ">=0.40.0,<1.0.0"
|
||||
typing-extensions = ">=4.8.0"
|
||||
typing-inspection = ">=0.4.2"
|
||||
|
||||
[package.extras]
|
||||
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||
|
||||
@@ -1062,14 +1063,14 @@ urllib3 = ">=1.26.0,<3"
|
||||
|
||||
[[package]]
|
||||
name = "launchdarkly-server-sdk"
|
||||
version = "9.14.1"
|
||||
version = "9.15.0"
|
||||
description = "LaunchDarkly SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
|
||||
{file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
|
||||
{file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
|
||||
{file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1478,14 +1479,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
|
||||
{file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
|
||||
{file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
|
||||
{file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2248,14 +2249,14 @@ cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "realtime"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
|
||||
{file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
|
||||
{file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
|
||||
{file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2436,14 +2437,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
|
||||
{file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
|
||||
{file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
|
||||
{file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2487,35 +2488,35 @@ python-dateutil = ">=2.6.0"
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
|
||||
{file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
|
||||
{file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
|
||||
{file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = "2.27.2"
|
||||
realtime = "2.27.2"
|
||||
storage3 = "2.27.2"
|
||||
supabase-auth = "2.27.2"
|
||||
supabase-functions = "2.27.2"
|
||||
postgrest = "2.28.0"
|
||||
realtime = "2.28.0"
|
||||
storage3 = "2.28.0"
|
||||
supabase-auth = "2.28.0"
|
||||
supabase-functions = "2.28.0"
|
||||
yarl = ">=1.22.0"
|
||||
|
||||
[[package]]
|
||||
name = "supabase-auth"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
|
||||
{file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
|
||||
{file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
|
||||
{file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2525,14 +2526,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
||||
|
||||
[[package]]
|
||||
name = "supabase-functions"
|
||||
version = "2.27.2"
|
||||
version = "2.28.0"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
|
||||
{file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
|
||||
{file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
|
||||
{file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2911,4 +2912,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
|
||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||
|
||||
@@ -11,14 +11,14 @@ python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^46.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.128.0"
|
||||
fastapi = "^0.128.7"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
launchdarkly-server-sdk = "^9.15.0"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.27.2"
|
||||
supabase = "^2.28.0"
|
||||
uvicorn = "^0.40.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -104,6 +104,12 @@ TWITTER_CLIENT_SECRET=
|
||||
# Make a new workspace for your OAuth APP -- trust me
|
||||
# https://linear.app/settings/api/applications/new
|
||||
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
||||
LINEAR_API_KEY=
|
||||
# Linear project and team IDs for the feature request tracker.
|
||||
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
|
||||
# and in team settings. Used by the chat copilot to file and search feature requests.
|
||||
LINEAR_FEATURE_REQUEST_PROJECT_ID=
|
||||
LINEAR_FEATURE_REQUEST_TEAM_ID=
|
||||
LINEAR_CLIENT_ID=
|
||||
LINEAR_CLIENT_SECRET=
|
||||
|
||||
@@ -184,5 +190,8 @@ ZEROBOUNCE_API_KEY=
|
||||
POSTHOG_API_KEY=
|
||||
POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# Tally Form Integration (pre-populate business understanding on signup)
|
||||
TALLY_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# ============================ DEPENDENCY BUILDER ============================ #
|
||||
|
||||
FROM debian:13-slim AS builder
|
||||
|
||||
# Set environment variables
|
||||
@@ -51,27 +53,62 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
# =============================== DB MIGRATOR =============================== #
|
||||
|
||||
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
|
||||
FROM debian:13-slim AS migrate
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy Node.js from builder (needed for Prisma CLI)
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
|
||||
# Copy Prisma binaries
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
# Install prisma-client-py directly (much smaller than copying full venv)
|
||||
RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
COPY autogpt_platform/backend/migrations ./migrations
|
||||
|
||||
# ============================== BACKEND SERVER ============================== #
|
||||
|
||||
FROM debian:13-slim AS server
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry \
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=true \
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||
# for the bash_exec MCP tool.
|
||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
imagemagick \
|
||||
jq \
|
||||
ripgrep \
|
||||
tree \
|
||||
bubblewrap \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Node.js installation for Prisma
|
||||
@@ -81,30 +118,25 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
||||
RUN mkdir -p /app/autogpt_platform/backend
|
||||
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
FROM server_dependencies AS migrate
|
||||
# Copy only the .venv from builder (not the entire /app directory)
|
||||
# The .venv includes the generated Prisma client
|
||||
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
||||
# Copy dependency files + autogpt_libs (path dependency)
|
||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
|
||||
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
# Copy backend code + docs (for Copilot docs search)
|
||||
COPY autogpt_platform/backend ./
|
||||
COPY docs /app/docs
|
||||
RUN poetry install --no-ansi --only-root
|
||||
# Install the project package to create entry point scripts in .venv/bin/
|
||||
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
|
||||
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
CMD ["poetry", "run", "rest"]
|
||||
CMD ["rest"]
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Common test fixtures for server tests."""
|
||||
"""Common test fixtures for server tests.
|
||||
|
||||
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
|
||||
setup_test_user, and setup_admin_user are defined in the parent conftest.py
|
||||
(backend/conftest.py) and are available here automatically.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
return snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_user(test_user_id):
|
||||
"""Provide mock JWT payload for regular user testing."""
|
||||
|
||||
@@ -88,20 +88,23 @@ async def require_auth(
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
def require_permission(*permissions: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
Dependency function for checking required permissions.
|
||||
All listed permissions must be present.
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
async def check_permissions(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
missing = [p for p in permissions if p not in auth.scopes]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
detail=f"Missing required permission(s): "
|
||||
f"{', '.join(p.value for p in missing)}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
return check_permissions
|
||||
|
||||
@@ -18,6 +18,7 @@ from backend.data import user as user_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .integrations import integrations_router
|
||||
@@ -95,6 +96,43 @@ async def execute_graph_block(
|
||||
return output
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs",
|
||||
tags=["graphs"],
|
||||
status_code=201,
|
||||
dependencies=[
|
||||
Security(
|
||||
require_permission(
|
||||
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
async def create_graph(
|
||||
graph: graph_db.Graph,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
|
||||
),
|
||||
) -> graph_db.GraphModel:
|
||||
"""
|
||||
Create a new agent graph.
|
||||
|
||||
The graph will be validated and assigned a new ID.
|
||||
It is automatically added to the user's library.
|
||||
"""
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
graph_model = graph_db.make_graph_model(graph, auth.user_id)
|
||||
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
|
||||
graph_model.validate_graph(for_run=False)
|
||||
|
||||
await graph_db.create_graph(graph_model, user_id=auth.user_id)
|
||||
await library_db.create_library_agent(graph_model, auth.user_id)
|
||||
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
|
||||
|
||||
return activated_graph
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
tags=["graphs"],
|
||||
|
||||
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import find_agent_tool, run_agent_tool
|
||||
from backend.copilot.tools.models import ToolResponseBase
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence, get_args, get_origin
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from prisma.models import mv_suggested_blocks
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks._base import (
|
||||
AnyBlockSchema,
|
||||
@@ -19,7 +21,6 @@ from backend.blocks._base import (
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
@@ -42,6 +43,16 @@ MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
||||
|
||||
# Boost blocks over marketplace agents in search results
|
||||
BLOCK_SCORE_BOOST = 50.0
|
||||
|
||||
# Block IDs to exclude from search results
|
||||
EXCLUDED_BLOCK_IDS = frozenset(
|
||||
{
|
||||
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
|
||||
}
|
||||
)
|
||||
|
||||
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
||||
|
||||
|
||||
@@ -64,8 +75,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
# Skip disabled and excluded blocks
|
||||
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
@@ -116,6 +127,9 @@ def get_blocks(
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
@@ -255,14 +269,25 @@ async def _build_cached_search_results(
|
||||
"my_agents": 0,
|
||||
}
|
||||
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
normalized_query=normalized_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
# Use hybrid search when query is present, otherwise list all blocks
|
||||
if (include_blocks or include_integrations) and normalized_query:
|
||||
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||
query=search_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
elif include_blocks or include_integrations:
|
||||
# No query - list all blocks using in-memory approach
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
|
||||
if include_library_agents:
|
||||
library_response = await library_db.list_library_agents(
|
||||
@@ -307,10 +332,14 @@ async def _build_cached_search_results(
|
||||
|
||||
def _collect_block_results(
|
||||
*,
|
||||
normalized_query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Collect all blocks for listing (no search query).
|
||||
|
||||
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
@@ -323,6 +352,10 @@ def _collect_block_results(
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Skip excluded blocks
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
block_info = block.get_info()
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
is_integration = len(credentials) > 0
|
||||
@@ -332,10 +365,6 @@ def _collect_block_results(
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
score = _score_block(block, block_info, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
@@ -346,8 +375,122 @@ def _collect_block_results(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=score,
|
||||
sort_key=_get_item_name(block_info),
|
||||
score=BLOCK_SCORE_BOOST,
|
||||
sort_key=block_info.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
async def _hybrid_search_blocks(
|
||||
*,
|
||||
query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
"""
|
||||
Search blocks using hybrid search with builder-specific filtering.
|
||||
|
||||
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||
post-filtering for block/integration types and scoring adjustments.
|
||||
|
||||
Scoring:
|
||||
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||
to prioritize blocks over marketplace agents in combined results
|
||||
- +30 for exact name match, +15 for prefix name match
|
||||
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
include_blocks: Whether to include regular blocks
|
||||
include_integrations: Whether to include integration blocks
|
||||
|
||||
Returns:
|
||||
Tuple of (scored_items, block_count, integration_count)
|
||||
"""
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
|
||||
normalized_query = query.strip().lower()
|
||||
|
||||
# Fetch more results to account for post-filtering
|
||||
search_results, _ = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=150,
|
||||
min_score=0.10,
|
||||
)
|
||||
|
||||
# Load all blocks for getting BlockInfo
|
||||
all_blocks = load_all_blocks()
|
||||
|
||||
for result in search_results:
|
||||
block_id = result["content_id"]
|
||||
|
||||
# Skip excluded blocks
|
||||
if block_id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
hybrid_score = result.get("relevance", 0.0)
|
||||
|
||||
# Get the actual block class
|
||||
if block_id not in all_blocks:
|
||||
continue
|
||||
|
||||
block_cls = all_blocks[block_id]
|
||||
block: AnyBlockSchema = block_cls()
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
# Check block/integration filter using metadata
|
||||
is_integration = metadata.get("is_integration", False)
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
# Get block info
|
||||
block_info = block.get_info()
|
||||
|
||||
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||
|
||||
# Add LLM model match bonus
|
||||
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||
final_score += 20
|
||||
|
||||
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||
name = block_info.name.lower()
|
||||
if name == normalized_query:
|
||||
final_score += 30
|
||||
elif name.startswith(normalized_query):
|
||||
final_score += 15
|
||||
|
||||
# Track counts
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=final_score,
|
||||
sort_key=name,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -472,6 +615,8 @@ async def _get_static_counts():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
all_blocks += 1
|
||||
|
||||
@@ -498,47 +643,25 @@ async def _get_static_counts():
|
||||
}
|
||||
|
||||
|
||||
def _contains_type(annotation: Any, target: type) -> bool:
|
||||
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||
if annotation is target:
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||
|
||||
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
if _contains_type(field.annotation, LlmModel):
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _score_block(
|
||||
block: AnyBlockSchema,
|
||||
block_info: BlockInfo,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = block_info.name.lower()
|
||||
description = block_info.description.lower()
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
category_text = " ".join(
|
||||
category.get("category", "").lower() for category in block_info.categories
|
||||
)
|
||||
score += _score_additional_field(category_text, normalized_query, 12, 6)
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info
|
||||
for provider in info.provider
|
||||
]
|
||||
provider_text = " ".join(provider_names)
|
||||
score += _score_additional_field(provider_text, normalized_query, 15, 6)
|
||||
|
||||
if _matches_llm_model(block.input_schema, normalized_query):
|
||||
score += 20
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_library_agent(
|
||||
agent: library_model.LibraryAgent,
|
||||
normalized_query: str,
|
||||
@@ -645,31 +768,20 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
@cached(ttl_seconds=3600, shared_cache=True)
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
suggested_blocks = []
|
||||
# Sum the number of executions for each block type
|
||||
# Prisma cannot group by nested relations, so we do a raw query
|
||||
# Calculate the cutoff timestamp
|
||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
"""Return the most-executed blocks from the last 14 days.
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
agent_node."agentBlockId" AS block_id,
|
||||
COUNT(execution.id) AS execution_count
|
||||
FROM {schema_prefix}"AgentNodeExecution" execution
|
||||
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
||||
WHERE execution."endedTime" >= $1::timestamp
|
||||
GROUP BY agent_node."agentBlockId"
|
||||
ORDER BY execution_count DESC;
|
||||
""",
|
||||
timestamp_threshold,
|
||||
)
|
||||
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
|
||||
and returns the top `count` blocks sorted by execution count, excluding
|
||||
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
|
||||
"""
|
||||
results = await mv_suggested_blocks.prisma().find_many()
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
# But ignore Input, Output, Agent, and excluded blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
execution_counts = {row.block_id: row.execution_count for row in results}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
@@ -679,11 +791,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
BlockType.AGENT,
|
||||
):
|
||||
continue
|
||||
# Find the execution count for this block
|
||||
execution_count = next(
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
if block.id in EXCLUDED_BLOCK_IDS:
|
||||
continue
|
||||
execution_count = execution_counts.get(block.id, 0)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
@@ -27,7 +27,6 @@ class SearchEntry(BaseModel):
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[SearchEntry]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockInfo]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Annotated, Sequence
|
||||
from typing import Annotated, Sequence, cast, get_args
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
@@ -10,6 +10,8 @@ from backend.util.models import Pagination
|
||||
from . import db as builder_db
|
||||
from . import model as builder_model
|
||||
|
||||
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
@@ -49,11 +51,6 @@ async def get_suggestions(
|
||||
Get all suggestions for the Blocks Menu.
|
||||
"""
|
||||
return builder_model.SuggestionsResponse(
|
||||
otto_suggestions=[
|
||||
"What blocks do I need to get started?",
|
||||
"Help me create a list",
|
||||
"Help me feed my data to Google Maps",
|
||||
],
|
||||
recent_searches=await builder_db.get_recent_searches(user_id),
|
||||
providers=[
|
||||
ProviderName.TWITTER,
|
||||
@@ -151,7 +148,7 @@ async def get_providers(
|
||||
async def search(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
search_query: Annotated[str | None, fastapi.Query()] = None,
|
||||
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
|
||||
filter: Annotated[str | None, fastapi.Query()] = None,
|
||||
search_id: Annotated[str | None, fastapi.Query()] = None,
|
||||
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||
page: Annotated[int, fastapi.Query()] = 1,
|
||||
@@ -160,9 +157,20 @@ async def search(
|
||||
"""
|
||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||
"""
|
||||
# If no filters are provided, then we will return all types
|
||||
if not filter:
|
||||
filter = [
|
||||
# Parse and validate filter parameter
|
||||
filters: list[builder_model.FilterType]
|
||||
if filter:
|
||||
filter_values = [f.strip() for f in filter.split(",")]
|
||||
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
|
||||
if invalid_filters:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
|
||||
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
|
||||
)
|
||||
filters = cast(list[builder_model.FilterType], filter_values)
|
||||
else:
|
||||
filters = [
|
||||
"blocks",
|
||||
"integrations",
|
||||
"marketplace_agents",
|
||||
@@ -174,7 +182,7 @@ async def search(
|
||||
cached_results = await builder_db.get_sorted_search_results(
|
||||
user_id=user_id,
|
||||
search_query=search_query,
|
||||
filters=filter,
|
||||
filters=filters,
|
||||
by_creator=by_creator,
|
||||
)
|
||||
|
||||
@@ -196,7 +204,7 @@ async def search(
|
||||
user_id,
|
||||
builder_model.SearchEntry(
|
||||
search_query=search_query,
|
||||
filter=filter,
|
||||
filter=filters,
|
||||
by_creator=by_creator,
|
||||
search_id=search_id,
|
||||
),
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
This consumer initializes its own Prisma client in start() to ensure
|
||||
database operations work correctly within this async context.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._prisma: Prisma | None = None
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def _ensure_prisma(self) -> Prisma:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._prisma:
|
||||
await self._prisma.disconnect()
|
||||
self._prisma = None
|
||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message using our own Prisma client."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_success(task, message.result, prisma)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -1,344 +0,0 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
prisma_client: Prisma | None,
|
||||
) -> None:
|
||||
"""Update tool message in database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The caller should
|
||||
handle this to avoid marking the task as completed with inconsistent state.
|
||||
"""
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
updated_count = await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={"content": content},
|
||||
)
|
||||
# Check if any rows were updated - 0 means message not found
|
||||
if updated_count == 0:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=content,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task will be
|
||||
marked as failed instead of completed to avoid inconsistent state.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database with
|
||||
the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -1,29 +1,39 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
from .tools.models import (
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AgentsFoundResponse,
|
||||
BlockDetailsResponse,
|
||||
BlockListResponse,
|
||||
BlockOutputResponse,
|
||||
ClarificationNeededResponse,
|
||||
@@ -34,15 +44,19 @@ from .tools.models import (
|
||||
InputValidationErrorResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,6 +85,9 @@ class StreamChatRequest(BaseModel):
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
@@ -84,10 +101,8 @@ class CreateSessionResponse(BaseModel):
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
turn_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
@@ -117,12 +132,11 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
class CancelSessionResponse(BaseModel):
|
||||
"""Response model for the cancel session endpoint."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
cancelled: bool
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
@@ -199,6 +213,43 @@ async def create_session(
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/sessions/{session_id}",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=204,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def delete_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> Response:
|
||||
"""
|
||||
Delete a chat session.
|
||||
|
||||
Permanently removes a chat session and all its messages.
|
||||
Only the owner can delete their sessions.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
204 No Content on success.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
deleted = await delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -210,7 +261,7 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
@@ -228,24 +279,21 @@ async def get_session(
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_session:
|
||||
# Keep the assistant message (including tool_calls) so the frontend can
|
||||
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||
# tool parts without output to state "input-available".
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
turn_id=active_session.turn_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -258,6 +306,51 @@ async def get_session(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/cancel",
|
||||
status_code=200,
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
|
||||
polls Redis until the task status flips from ``running`` or a timeout
|
||||
(5 s) is reached. Returns only after the cancellation is confirmed.
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||
if not active_session:
|
||||
return CancelSessionResponse(cancelled=True, reason="no_active_session")
|
||||
|
||||
await enqueue_cancel_task(session_id)
|
||||
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
||||
|
||||
# Poll until the executor confirms the task is no longer running.
|
||||
poll_interval = 0.5
|
||||
max_wait = 5.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
await asyncio.sleep(poll_interval)
|
||||
waited += poll_interval
|
||||
session_state = await stream_registry.get_session(session_id)
|
||||
if session_state is None or session_state.status != "running":
|
||||
logger.info(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
||||
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
|
||||
)
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
logger.warning(
|
||||
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
|
||||
)
|
||||
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
|
||||
return CancelSessionResponse(cancelled=True)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
@@ -275,16 +368,15 @@ async def stream_chat_post(
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
@@ -300,10 +392,9 @@ async def stream_chat_post(
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -312,106 +403,95 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
import time as time_module
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
gen_start_time = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||
):
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
ttfc = first_chunk_time - gen_start_time
|
||||
logger.info(
|
||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"time_to_first_chunk_ms": ttfc * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
)
|
||||
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time,
|
||||
"time_to_first_chunk_ms": (
|
||||
ttfc * 1000 if ttfc is not None else None
|
||||
),
|
||||
"n_chunks": chunk_count,
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
elapsed = time_module.perf_counter() - gen_start_time
|
||||
logger.error(
|
||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
"error": str(e),
|
||||
}
|
||||
},
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
@@ -421,7 +501,7 @@ async def stream_chat_post(
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
@@ -429,11 +509,12 @@ async def stream_chat_post(
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
last_message_id=subscribe_from_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -448,7 +529,7 @@ async def stream_chat_post(
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
@@ -506,23 +587,29 @@ async def stream_chat_post(
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
# Surface error to frontend so it doesn't appear stuck
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -569,17 +656,21 @@ async def resume_session_stream(
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
active_session, last_message_id = await stream_registry.get_active_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_task:
|
||||
if not active_session:
|
||||
return Response(status_code=204)
|
||||
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
# Always replay from the beginning ("0-0") on resume.
|
||||
# We can't use last_message_id because it's the latest ID in the backend
|
||||
# stream, not the latest the frontend received — the gap causes lost
|
||||
# messages. The frontend deduplicates replayed content.
|
||||
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
last_message_id="0-0",
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
@@ -591,7 +682,7 @@ async def resume_session_stream(
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
@@ -615,12 +706,12 @@ async def resume_session_stream(
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
await stream_registry.unsubscribe_from_session(
|
||||
session_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
@@ -671,231 +762,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@@ -970,13 +836,12 @@ ToolResponseUnion = (
|
||||
| AgentPreviewResponse
|
||||
| AgentSavedResponse
|
||||
| ClarificationNeededResponse
|
||||
| SuggestedGoalResponse
|
||||
| BlockListResponse
|
||||
| BlockDetailsResponse
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests for chat route file_ids validation and enrichment."""
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.chat import routes as chat_routes
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(chat_routes.router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---- file_ids Pydantic validation (B1) ----
|
||||
|
||||
|
||||
def test_stream_chat_rejects_too_many_file_ids():
|
||||
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_registry = mocker.MagicMock()
|
||||
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||
},
|
||||
)
|
||||
# Should get past validation — 200 streaming response expected
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- UUID format filtering ----
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "ws-1"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={
|
||||
"message": "hello",
|
||||
"file_ids": [
|
||||
valid_id,
|
||||
"not-a-uuid",
|
||||
"../../../etc/passwd",
|
||||
"",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# The find_many call should only receive the one valid UUID
|
||||
mock_prisma.find_many.assert_called_once()
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||
|
||||
|
||||
# ---- Cross-workspace file_ids ----
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||
)
|
||||
|
||||
mock_prisma = mocker.MagicMock()
|
||||
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||
mocker.patch(
|
||||
"prisma.models.UserWorkspaceFile.prisma",
|
||||
return_value=mock_prisma,
|
||||
)
|
||||
|
||||
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hi", "file_ids": [fid]},
|
||||
)
|
||||
|
||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
@@ -1,82 +0,0 @@
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
assistant_message = ""
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert assistant_message, "Assistant message is empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
had_tool_calls = False
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session.session_id,
|
||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
||||
user_id=session.user_id,
|
||||
):
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from backend.api.features.chat.tools.models import BlockListResponse
|
||||
from backend.blocks._base import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-find-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.description = f"{name} description"
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = {}
|
||||
mock.categories = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestFindBlockFiltering:
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
def test_excluded_block_types_contains_expected_types(self):
|
||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_filtered_from_results(self):
|
||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||
search_results = [
|
||||
{"content_id": "input-block-id", "score": 0.9},
|
||||
{"content_id": "standard-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
standard_block = make_mock_block(
|
||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
"input-block-id": input_block,
|
||||
"standard-block-id": standard_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="test"
|
||||
)
|
||||
|
||||
# Should only return the standard block, not the INPUT block
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "standard-block-id"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_filtered_from_results(self):
|
||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
search_results = [
|
||||
{"content_id": smart_decision_id, "score": 0.9},
|
||||
{"content_id": "normal-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
normal_block = make_mock_block(
|
||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
smart_decision_id: smart_block,
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||
)
|
||||
|
||||
# Should only return normal block, not SmartDecisionMakerBlock
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.models import ErrorResponse
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.blocks._base import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-run-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestRunBlockFiltering:
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_returns_error(self):
|
||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=input_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="input-block-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "designed for use within graphs only" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_returns_error(self):
|
||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=smart_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=smart_decision_id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_excluded_block_passes_guard(self):
|
||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
standard_block = make_mock_block(
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="standard-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||
if isinstance(response, ErrorResponse):
|
||||
assert "cannot be run directly in CoPilot" not in response.message
|
||||
@@ -1,620 +0,0 @@
|
||||
"""CoPilot tools for workspace file operations."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceFileListResponse(ToolResponseBase):
|
||||
"""Response containing list of workspace files."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||
files: list[WorkspaceFileInfoData]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
"""Response containing workspace file content (legacy, for small text files)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
download_url: str
|
||||
preview: str | None = None # First 500 chars for text files
|
||||
|
||||
|
||||
class WorkspaceWriteResponse(ToolResponseBase):
|
||||
"""Response after writing a file to workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
"""Response after deleting a file from workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||
file_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_workspace_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's workspace. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mimeType,
|
||||
size_bytes=f.sizeBytes,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, always return metadata+URL instead of inline content. "
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||
"""Check if the MIME type is a text-based type."""
|
||||
text_types = [
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
]
|
||||
return any(mime_type.startswith(t) for t in text_types)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Get file info
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||
|
||||
# Return inline content for small text files (unless force_download_url)
|
||||
if is_small_file and is_text_file and not force_download_url:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Return metadata + workspace:// reference for large or binary files
|
||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||
download_url = f"workspace://{target_file_id}"
|
||||
|
||||
# Generate preview for text files
|
||||
preview: str | None = None
|
||||
if is_text_file:
|
||||
try:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
if len(content) > self.PREVIEW_SIZE:
|
||||
preview_text += "..."
|
||||
preview = preview_text
|
||||
except Exception:
|
||||
pass # Preview is optional
|
||||
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
size_bytes=file_info.sizeBytes,
|
||||
download_url=download_url,
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WriteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for writing files to workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the user's workspace. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded file content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional virtual path where to save the file "
|
||||
"(e.g., '/documents/report.pdf'). "
|
||||
"Defaults to '/{filename}'. Scoped to current session."
|
||||
),
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional MIME type of the file. "
|
||||
"Auto-detected from filename if not provided."
|
||||
),
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename", "content_base64"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check size
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
file_record = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.sizeBytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for deleting files from workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
success = await manager.delete_file(target_file_id)
|
||||
|
||||
if not success:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message="File deleted successfully",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Annotated, List, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import (
|
||||
@@ -14,7 +14,7 @@ from fastapi import (
|
||||
Security,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||
|
||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||
@@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.credentials_store import provider_matches
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
create_mcp_oauth_handler,
|
||||
)
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel):
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
default=None,
|
||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_provider(cls, data: Any) -> Any:
|
||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
||||
if isinstance(data, dict):
|
||||
prov = data.get("provider", "")
|
||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||
member = prov.removeprefix("ProviderName.")
|
||||
try:
|
||||
data = {**data, "provider": ProviderName[member].value}
|
||||
except KeyError:
|
||||
pass
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def get_host(cred: Credentials) -> str | None:
|
||||
"""Extract host from credential: HostScoped host or MCP server URL."""
|
||||
if isinstance(cred, HostScopedCredentials):
|
||||
return cred.host
|
||||
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
||||
ProviderName.MCP,
|
||||
ProviderName.MCP.value,
|
||||
"ProviderName.MCP",
|
||||
):
|
||||
return (cred.metadata or {}).get("mcp_server_url")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||
async def callback(
|
||||
@@ -179,9 +211,7 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||
)
|
||||
|
||||
|
||||
@@ -199,7 +229,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -222,7 +252,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
host=CredentialsMetaResponse.get_host(cred),
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -322,7 +352,11 @@ async def delete_credentials(
|
||||
|
||||
tokens_revoked = None
|
||||
if isinstance(creds, OAuth2Credentials):
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
if provider_matches(provider.value, ProviderName.MCP.value):
|
||||
# MCP uses dynamic per-server OAuth — create handler from metadata
|
||||
handler = create_mcp_oauth_handler(creds)
|
||||
else:
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
tokens_revoked = await handler.revoke_tokens(creds)
|
||||
|
||||
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -144,6 +144,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
@@ -178,7 +179,6 @@ async def test_add_agent_to_library(mocker):
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
# Check that create was called with the expected data including settings
|
||||
create_call_args = mock_library_agent.return_value.create.call_args
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
class FolderValidationError(Exception):
|
||||
"""Raised when folder operations fail validation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FolderAlreadyExistsError(FolderValidationError):
|
||||
"""Raised when a folder with the same name already exists in the location."""
|
||||
|
||||
pass
|
||||
@@ -26,6 +26,95 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
# === Folder Models ===
|
||||
|
||||
|
||||
class LibraryFolder(pydantic.BaseModel):
|
||||
"""Represents a folder for organizing library agents."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
parent_id: str | None = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
agent_count: int = 0 # Direct agents in folder
|
||||
subfolder_count: int = 0 # Direct child folders
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
folder: prisma.models.LibraryFolder,
|
||||
agent_count: int = 0,
|
||||
subfolder_count: int = 0,
|
||||
) -> "LibraryFolder":
|
||||
"""Factory method that constructs a LibraryFolder from a Prisma model."""
|
||||
return LibraryFolder(
|
||||
id=folder.id,
|
||||
user_id=folder.userId,
|
||||
name=folder.name,
|
||||
icon=folder.icon,
|
||||
color=folder.color,
|
||||
parent_id=folder.parentId,
|
||||
created_at=folder.createdAt,
|
||||
updated_at=folder.updatedAt,
|
||||
agent_count=agent_count,
|
||||
subfolder_count=subfolder_count,
|
||||
)
|
||||
|
||||
|
||||
class LibraryFolderTree(LibraryFolder):
|
||||
"""Folder with nested children for tree view."""
|
||||
|
||||
children: list["LibraryFolderTree"] = []
|
||||
|
||||
|
||||
class FolderCreateRequest(pydantic.BaseModel):
|
||||
"""Request model for creating a folder."""
|
||||
|
||||
name: str = pydantic.Field(..., min_length=1, max_length=100)
|
||||
icon: str | None = None
|
||||
color: str | None = pydantic.Field(
|
||||
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
|
||||
)
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class FolderUpdateRequest(pydantic.BaseModel):
|
||||
"""Request model for updating a folder."""
|
||||
|
||||
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
|
||||
icon: str | None = None
|
||||
color: str | None = None
|
||||
|
||||
|
||||
class FolderMoveRequest(pydantic.BaseModel):
|
||||
"""Request model for moving a folder to a new parent."""
|
||||
|
||||
target_parent_id: str | None = None # None = move to root
|
||||
|
||||
|
||||
class BulkMoveAgentsRequest(pydantic.BaseModel):
|
||||
"""Request model for moving multiple agents to a folder."""
|
||||
|
||||
agent_ids: list[str]
|
||||
folder_id: str | None = None # None = move to root
|
||||
|
||||
|
||||
class FolderListResponse(pydantic.BaseModel):
|
||||
"""Response schema for a list of folders."""
|
||||
|
||||
folders: list[LibraryFolder]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
class FolderTreeResponse(pydantic.BaseModel):
|
||||
"""Response schema for folder tree structure."""
|
||||
|
||||
tree: list[LibraryFolderTree]
|
||||
|
||||
|
||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||
"""Creator information for a marketplace listing."""
|
||||
|
||||
@@ -120,6 +209,9 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
can_access_graph: bool
|
||||
is_latest_version: bool
|
||||
is_favorite: bool
|
||||
folder_id: str | None = None
|
||||
folder_name: str | None = None # Denormalized for display
|
||||
|
||||
recommended_schedule_cron: str | None = None
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
@@ -259,6 +351,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
folder_id=agent.folderId,
|
||||
folder_name=agent.Folder.name if agent.Folder else None,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
settings=_parse_settings(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
@@ -470,3 +564,7 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
settings: Optional[GraphSettings] = pydantic.Field(
|
||||
default=None, description="User-specific settings for this library agent"
|
||||
)
|
||||
folder_id: Optional[str] = pydantic.Field(
|
||||
default=None,
|
||||
description="Folder ID to move agent to (None to move to root)",
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import fastapi
|
||||
|
||||
from .agents import router as agents_router
|
||||
from .folders import router as folders_router
|
||||
from .presets import router as presets_router
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
router.include_router(presets_router)
|
||||
router.include_router(folders_router)
|
||||
router.include_router(agents_router)
|
||||
|
||||
@@ -41,6 +41,14 @@ async def list_library_agents(
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
folder_id: Optional[str] = Query(
|
||||
None,
|
||||
description="Filter by folder ID",
|
||||
),
|
||||
include_root_only: bool = Query(
|
||||
False,
|
||||
description="Only return agents without a folder (root-level agents)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all agents in the user's library (both created and saved).
|
||||
@@ -51,6 +59,8 @@ async def list_library_agents(
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
folder_id=folder_id,
|
||||
include_root_only=include_root_only,
|
||||
)
|
||||
|
||||
|
||||
@@ -168,6 +178,7 @@ async def update_library_agent(
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
folder_id=payload.folder_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
from typing import Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/folders",
|
||||
tags=["library", "folders", "private"],
|
||||
dependencies=[Security(autogpt_auth_lib.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="List Library Folders",
|
||||
response_model=library_model.FolderListResponse,
|
||||
responses={
|
||||
200: {"description": "List of folders"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def list_folders(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
parent_id: Optional[str] = Query(
|
||||
None,
|
||||
description="Filter by parent folder ID. If not provided, returns root-level folders.",
|
||||
),
|
||||
include_relations: bool = Query(
|
||||
True,
|
||||
description="Include agent and subfolder relations (for counts)",
|
||||
),
|
||||
) -> library_model.FolderListResponse:
|
||||
"""
|
||||
List folders for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
parent_id: Optional parent folder ID to filter by.
|
||||
include_relations: Whether to include agent and subfolder relations for counts.
|
||||
|
||||
Returns:
|
||||
A FolderListResponse containing folders.
|
||||
"""
|
||||
folders = await library_db.list_folders(
|
||||
user_id=user_id,
|
||||
parent_id=parent_id,
|
||||
include_relations=include_relations,
|
||||
)
|
||||
return library_model.FolderListResponse(
|
||||
folders=folders,
|
||||
pagination=library_model.Pagination(
|
||||
total_items=len(folders),
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=len(folders),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tree",
|
||||
summary="Get Folder Tree",
|
||||
response_model=library_model.FolderTreeResponse,
|
||||
responses={
|
||||
200: {"description": "Folder tree structure"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def get_folder_tree(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.FolderTreeResponse:
|
||||
"""
|
||||
Get the full folder tree for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
A FolderTreeResponse containing the nested folder structure.
|
||||
"""
|
||||
tree = await library_db.get_folder_tree(user_id=user_id)
|
||||
return library_model.FolderTreeResponse(tree=tree)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{folder_id}",
|
||||
summary="Get Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder details"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def get_folder(
|
||||
folder_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Get a specific folder.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to retrieve.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The requested LibraryFolder.
|
||||
"""
|
||||
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Create Folder",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
201: {"description": "Folder created successfully"},
|
||||
400: {"description": "Validation error"},
|
||||
404: {"description": "Parent folder not found"},
|
||||
409: {"description": "Folder name conflict"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def create_folder(
|
||||
payload: library_model.FolderCreateRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Create a new folder.
|
||||
|
||||
Args:
|
||||
payload: The folder creation request.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The created LibraryFolder.
|
||||
"""
|
||||
return await library_db.create_folder(
|
||||
user_id=user_id,
|
||||
name=payload.name,
|
||||
parent_id=payload.parent_id,
|
||||
icon=payload.icon,
|
||||
color=payload.color,
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{folder_id}",
|
||||
summary="Update Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder updated successfully"},
|
||||
400: {"description": "Validation error"},
|
||||
404: {"description": "Folder not found"},
|
||||
409: {"description": "Folder name conflict"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def update_folder(
|
||||
folder_id: str,
|
||||
payload: library_model.FolderUpdateRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Update a folder's properties.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to update.
|
||||
payload: The folder update request.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The updated LibraryFolder.
|
||||
"""
|
||||
return await library_db.update_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
name=payload.name,
|
||||
icon=payload.icon,
|
||||
color=payload.color,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{folder_id}/move",
|
||||
summary="Move Folder",
|
||||
response_model=library_model.LibraryFolder,
|
||||
responses={
|
||||
200: {"description": "Folder moved successfully"},
|
||||
400: {"description": "Validation error (circular reference)"},
|
||||
404: {"description": "Folder or target parent not found"},
|
||||
409: {"description": "Folder name conflict in target location"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def move_folder(
|
||||
folder_id: str,
|
||||
payload: library_model.FolderMoveRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryFolder:
|
||||
"""
|
||||
Move a folder to a new parent.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to move.
|
||||
payload: The move request with target parent.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The moved LibraryFolder.
|
||||
"""
|
||||
return await library_db.move_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
target_parent_id=payload.target_parent_id,
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{folder_id}",
|
||||
summary="Delete Folder",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "Folder deleted successfully"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def delete_folder(
|
||||
folder_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete a folder and all its contents.
|
||||
|
||||
Args:
|
||||
folder_id: ID of the folder to delete.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 No Content if successful.
|
||||
"""
|
||||
await library_db.delete_folder(
|
||||
folder_id=folder_id,
|
||||
user_id=user_id,
|
||||
soft_delete=True,
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
# === Bulk Agent Operations ===
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/bulk-move",
|
||||
summary="Bulk Move Agents",
|
||||
response_model=list[library_model.LibraryAgent],
|
||||
responses={
|
||||
200: {"description": "Agents moved successfully"},
|
||||
404: {"description": "Folder not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def bulk_move_agents(
|
||||
payload: library_model.BulkMoveAgentsRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
Move multiple agents to a folder.
|
||||
|
||||
Args:
|
||||
payload: The bulk move request with agent IDs and target folder.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgents.
|
||||
"""
|
||||
return await library_db.bulk_move_agents_to_folder(
|
||||
agent_ids=payload.agent_ids,
|
||||
folder_id=payload.folder_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -115,6 +115,8 @@ async def test_get_library_agents_success(
|
||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
||||
page=1,
|
||||
page_size=15,
|
||||
folder_id=None,
|
||||
include_root_only=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
404
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
404
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) API routes.
|
||||
|
||||
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
||||
frontend can list available tools on an MCP server before placing a block.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import Security
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import HTTPClientError, Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = Settings()
|
||||
router = fastapi.APIRouter(tags=["mcp"])
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# ====================== Tool Discovery ====================== #
|
||||
|
||||
|
||||
class DiscoverToolsRequest(BaseModel):
|
||||
"""Request to discover tools on an MCP server."""
|
||||
|
||||
server_url: str = Field(description="URL of the MCP server")
|
||||
auth_token: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Bearer token for authenticated MCP servers",
|
||||
)
|
||||
|
||||
|
||||
class MCPToolResponse(BaseModel):
|
||||
"""A single MCP tool returned by discovery."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
class DiscoverToolsResponse(BaseModel):
|
||||
"""Response containing the list of tools available on an MCP server."""
|
||||
|
||||
tools: list[MCPToolResponse]
|
||||
server_name: str | None = None
|
||||
protocol_version: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/discover-tools",
|
||||
summary="Discover available tools on an MCP server",
|
||||
response_model=DiscoverToolsResponse,
|
||||
)
|
||||
async def discover_tools(
|
||||
request: DiscoverToolsRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> DiscoverToolsResponse:
|
||||
"""
|
||||
Connect to an MCP server and return its available tools.
|
||||
|
||||
If the user has a stored MCP credential for this server URL, it will be
|
||||
used automatically — no need to pass an explicit auth token.
|
||||
"""
|
||||
auth_token = request.auth_token
|
||||
|
||||
# Auto-use stored MCP credential when no explicit token is provided.
|
||||
if not auth_token:
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
||||
):
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
# Refresh the token if expired before using it
|
||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
|
||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||
|
||||
try:
|
||||
init_result = await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
except HTTPClientError as e:
|
||||
if e.status_code in (401, 403):
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401,
|
||||
detail="This MCP server requires authentication. "
|
||||
"Please provide a valid auth token.",
|
||||
)
|
||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||
except MCPClientError as e:
|
||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to connect to MCP server: {e}",
|
||||
)
|
||||
|
||||
return DiscoverToolsResponse(
|
||||
tools=[
|
||||
MCPToolResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
server_name=(
|
||||
init_result.get("serverInfo", {}).get("name")
|
||||
or urlparse(request.server_url).hostname
|
||||
or "MCP"
|
||||
),
|
||||
protocol_version=init_result.get("protocolVersion"),
|
||||
)
|
||||
|
||||
|
||||
# ======================== OAuth Flow ======================== #
|
||||
|
||||
|
||||
class MCPOAuthLoginRequest(BaseModel):
|
||||
"""Request to start an OAuth flow for an MCP server."""
|
||||
|
||||
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
||||
|
||||
|
||||
class MCPOAuthLoginResponse(BaseModel):
|
||||
"""Response with the OAuth login URL for the user to authenticate."""
|
||||
|
||||
login_url: str
|
||||
state_token: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/login",
|
||||
summary="Initiate OAuth login for an MCP server",
|
||||
)
|
||||
async def mcp_oauth_login(
|
||||
request: MCPOAuthLoginRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> MCPOAuthLoginResponse:
|
||||
"""
|
||||
Discover OAuth metadata from the MCP server and return a login URL.
|
||||
|
||||
1. Discovers the protected-resource metadata (RFC 9728)
|
||||
2. Fetches the authorization server metadata (RFC 8414)
|
||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||
4. Returns the authorization URL for the frontend to open in a popup
|
||||
"""
|
||||
client = MCPClient(request.server_url)
|
||||
|
||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||
protected_resource = await client.discover_auth()
|
||||
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
if protected_resource and protected_resource.get("authorization_servers"):
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
else:
|
||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
||||
# and serve OAuth metadata directly without protected-resource metadata.
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
|
||||
if (
|
||||
not metadata
|
||||
or "authorization_endpoint" not in metadata
|
||||
or "token_endpoint" not in metadata
|
||||
):
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail="This MCP server does not advertise OAuth support. "
|
||||
"You may need to provide an auth token manually.",
|
||||
)
|
||||
|
||||
authorize_url = metadata["authorization_endpoint"]
|
||||
token_url = metadata["token_endpoint"]
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
revoke_url = metadata.get("revocation_endpoint")
|
||||
|
||||
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
if not frontend_base_url:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Frontend base URL is not configured.",
|
||||
)
|
||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||
|
||||
client_id = ""
|
||||
client_secret = ""
|
||||
if registration_endpoint:
|
||||
reg_result = await _register_mcp_client(
|
||||
registration_endpoint, redirect_uri, request.server_url
|
||||
)
|
||||
if reg_result:
|
||||
client_id = reg_result.get("client_id", "")
|
||||
client_secret = reg_result.get("client_secret", "")
|
||||
|
||||
if not client_id:
|
||||
client_id = "autogpt-platform"
|
||||
|
||||
# Step 4: Store state token with OAuth metadata for the callback
|
||||
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
||||
"scopes_supported", []
|
||||
)
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id,
|
||||
ProviderName.MCP.value,
|
||||
scopes,
|
||||
state_metadata={
|
||||
"authorize_url": authorize_url,
|
||||
"token_url": token_url,
|
||||
"revoke_url": revoke_url,
|
||||
"resource_url": resource_url,
|
||||
"server_url": request.server_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 5: Build and return the login URL
|
||||
handler = MCPOAuthHandler(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
authorize_url=authorize_url,
|
||||
token_url=token_url,
|
||||
resource_url=resource_url,
|
||||
)
|
||||
login_url = handler.get_login_url(
|
||||
scopes, state_token, code_challenge=code_challenge
|
||||
)
|
||||
|
||||
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
||||
|
||||
|
||||
class MCPOAuthCallbackRequest(BaseModel):
|
||||
"""Request to exchange an OAuth code for tokens."""
|
||||
|
||||
code: str = Field(description="Authorization code from OAuth callback")
|
||||
state_token: str = Field(description="State token for CSRF verification")
|
||||
|
||||
|
||||
class MCPOAuthCallbackResponse(BaseModel):
|
||||
"""Response after successfully storing OAuth credentials."""
|
||||
|
||||
credential_id: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/callback",
|
||||
summary="Exchange OAuth code for MCP tokens",
|
||||
)
|
||||
async def mcp_oauth_callback(
|
||||
request: MCPOAuthCallbackRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> CredentialsMetaResponse:
|
||||
"""
|
||||
Exchange the authorization code for tokens and store the credential.
|
||||
|
||||
The frontend calls this after receiving the OAuth code from the popup.
|
||||
On success, subsequent ``/discover-tools`` calls for the same server URL
|
||||
will automatically use the stored credential.
|
||||
"""
|
||||
valid_state = await creds_manager.store.verify_state_token(
|
||||
user_id, request.state_token, ProviderName.MCP.value
|
||||
)
|
||||
if not valid_state:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid or expired state token.",
|
||||
)
|
||||
|
||||
meta = valid_state.state_metadata
|
||||
frontend_base_url = settings.config.frontend_base_url
|
||||
if not frontend_base_url:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="Frontend base URL is not configured.",
|
||||
)
|
||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||
|
||||
handler = MCPOAuthHandler(
|
||||
client_id=meta["client_id"],
|
||||
client_secret=meta.get("client_secret", ""),
|
||||
redirect_uri=redirect_uri,
|
||||
authorize_url=meta["authorize_url"],
|
||||
token_url=meta["token_url"],
|
||||
revoke_url=meta.get("revoke_url"),
|
||||
resource_url=meta.get("resource_url"),
|
||||
)
|
||||
|
||||
try:
|
||||
credentials = await handler.exchange_code_for_tokens(
|
||||
request.code, valid_state.scopes, valid_state.code_verifier
|
||||
)
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail=f"OAuth token exchange failed: {e}",
|
||||
)
|
||||
|
||||
# Enrich credential metadata for future lookup and token refresh
|
||||
if credentials.metadata is None:
|
||||
credentials.metadata = {}
|
||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
for old in old_creds:
|
||||
if (
|
||||
isinstance(old, OAuth2Credentials)
|
||||
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(
|
||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
type=credentials.type,
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=credentials.metadata.get("mcp_server_url"),
|
||||
)
|
||||
|
||||
|
||||
# ======================== Helpers ======================== #
|
||||
|
||||
|
||||
async def _register_mcp_client(
|
||||
registration_endpoint: str,
|
||||
redirect_uri: str,
|
||||
server_url: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
||||
try:
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
registration_endpoint,
|
||||
json={
|
||||
"client_name": "AutoGPT Platform",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if isinstance(data, dict) and "client_id" in data:
|
||||
return data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||
return None
|
||||
436
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
436
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""Tests for MCP API routes.
|
||||
|
||||
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
||||
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||
from backend.util.request import HTTPClientError
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def client():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
mock_tools = [
|
||||
MCPTool(
|
||||
name="get_weather",
|
||||
description="Get weather for a city",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
),
|
||||
MCPTool(
|
||||
name="add_numbers",
|
||||
description="Add two numbers",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": {"name": "test-server"},
|
||||
}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["tools"]) == 2
|
||||
assert data["tools"][0]["name"] == "get_weather"
|
||||
assert data["tools"][1]["name"] == "add_numbers"
|
||||
assert data["server_name"] == "test-server"
|
||||
assert data["protocol_version"] == "2025-03-26"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_with_auth_token(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"auth_token": "my-secret-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
MockClient.assert_called_once_with(
|
||||
"https://mcp.example.com/mcp",
|
||||
auth_token="my-secret-token",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
stored_cred = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title="MCP: example.com",
|
||||
access_token=SecretStr("stored-token-123"),
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
MockClient.assert_called_once_with(
|
||||
"https://mcp.example.com/mcp",
|
||||
auth_token="stored-token-123",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "Connection refused" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://timeout.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "Failed to connect" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
):
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_missing_url(self, client):
|
||||
response = await client.post("/discover-tools", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_success(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch(
|
||||
"backend.api.features.mcp.routes._register_mcp_client"
|
||||
) as mock_register,
|
||||
):
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(
|
||||
return_value={
|
||||
"authorization_servers": ["https://auth.sentry.io"],
|
||||
"resource": "https://mcp.sentry.dev/mcp",
|
||||
"scopes_supported": ["openid"],
|
||||
}
|
||||
)
|
||||
instance.discover_auth_server_metadata = AsyncMock(
|
||||
return_value={
|
||||
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
||||
"token_endpoint": "https://auth.sentry.io/token",
|
||||
"registration_endpoint": "https://auth.sentry.io/register",
|
||||
}
|
||||
)
|
||||
mock_register.return_value = {
|
||||
"client_id": "registered-client-id",
|
||||
"client_secret": "registered-secret",
|
||||
}
|
||||
mock_cm.store.store_state_token = AsyncMock(
|
||||
return_value=("state-token-123", "code-challenge-abc")
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "login_url" in data
|
||||
assert data["state_token"] == "state-token-123"
|
||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
||||
assert "registered-client-id" in data["login_url"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_no_oauth_support(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(return_value=None)
|
||||
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "does not advertise OAuth" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_fallback_to_public_client(self, client):
|
||||
"""When DCR is unavailable, falls back to default public client ID."""
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
):
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(
|
||||
return_value={
|
||||
"authorization_servers": ["https://auth.example.com"],
|
||||
"resource": "https://mcp.example.com/mcp",
|
||||
}
|
||||
)
|
||||
instance.discover_auth_server_metadata = AsyncMock(
|
||||
return_value={
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
# No registration_endpoint
|
||||
}
|
||||
)
|
||||
mock_cm.store.store_state_token = AsyncMock(
|
||||
return_value=("state-abc", "challenge-xyz")
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "autogpt-platform" in data["login_url"]
|
||||
|
||||
|
||||
class TestOAuthCallback:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
mock_creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
title=None,
|
||||
access_token=SecretStr("access-token-xyz"),
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
metadata={
|
||||
"mcp_token_url": "https://auth.sentry.io/token",
|
||||
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||
):
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
# Mock state verification
|
||||
mock_state = AsyncMock()
|
||||
mock_state.state_metadata = {
|
||||
"authorize_url": "https://auth.sentry.io/authorize",
|
||||
"token_url": "https://auth.sentry.io/token",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-secret",
|
||||
"server_url": "https://mcp.sentry.dev/mcp",
|
||||
}
|
||||
mock_state.scopes = ["openid"]
|
||||
mock_state.code_verifier = "verifier-123"
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||
mock_cm.create = AsyncMock()
|
||||
|
||||
handler_instance = MockHandler.return_value
|
||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||
return_value=mock_creds
|
||||
)
|
||||
|
||||
# Mock old credential cleanup
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert data["provider"] == "mcp"
|
||||
assert data["type"] == "oauth2"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_invalid_state(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code", "state_token": "bad-state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid or expired" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_token_exchange_fails(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||
):
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
mock_state = AsyncMock()
|
||||
mock_state.state_metadata = {
|
||||
"authorize_url": "https://auth.example.com/authorize",
|
||||
"token_url": "https://auth.example.com/token",
|
||||
"client_id": "cid",
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
}
|
||||
mock_state.scopes = []
|
||||
mock_state.code_verifier = "v"
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||
|
||||
handler_instance = MockHandler.return_value
|
||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||
side_effect=RuntimeError("Token exchange failed")
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "bad-code", "state_token": "state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "token exchange failed" in response.json()["detail"].lower()
|
||||
@@ -9,15 +9,26 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _contains_type(annotation: Any, target: type) -> bool:
|
||||
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||
if annotation is target:
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
@@ -188,45 +199,51 @@ class BlockHandler(ContentHandler):
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Skip disabled blocks - they shouldn't be indexed
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
if block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
if block_instance.description:
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
if block_instance.categories:
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in categories] if categories else []
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info.values()
|
||||
for provider in info.provider
|
||||
]
|
||||
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -235,8 +252,11 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"name": block_instance.name,
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
|
||||
@@ -82,9 +82,10 @@ async def test_block_handler_get_missing_items(mocker):
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
@@ -309,19 +310,19 @@ async def test_content_handlers_registry():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
@@ -352,6 +353,8 @@ async def test_block_handler_skips_failed_blocks():
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
|
||||
@@ -126,6 +126,9 @@ v1_router = APIRouter()
|
||||
########################################################
|
||||
|
||||
|
||||
_tally_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
@@ -134,6 +137,24 @@ v1_router = APIRouter()
|
||||
)
|
||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||
user = await get_or_create_user(user_data)
|
||||
|
||||
# Fire-and-forget: populate business understanding from Tally form.
|
||||
# We use created_at proximity instead of an is_new flag because
|
||||
# get_or_create_user is cached — a separate is_new return value would be
|
||||
# unreliable on repeated calls within the cache TTL.
|
||||
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||
if age_seconds < 30:
|
||||
try:
|
||||
from backend.data.tally import populate_understanding_from_tally
|
||||
|
||||
task = asyncio.create_task(
|
||||
populate_understanding_from_tally(user.id, user.email)
|
||||
)
|
||||
_tally_background_tasks.add(task)
|
||||
task.add_done_callback(_tally_background_tasks.discard)
|
||||
except Exception:
|
||||
logger.debug("Failed to start Tally population task", exc_info=True)
|
||||
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -43,6 +43,7 @@ def test_get_or_create_user_route(
|
||||
) -> None:
|
||||
"""Test get or create user endpoint"""
|
||||
mock_user = Mock()
|
||||
mock_user.created_at = datetime.now(timezone.utc)
|
||||
mock_user.model_dump.return_value = {
|
||||
"id": test_user_id,
|
||||
"email": "test@example.com",
|
||||
|
||||
@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi import Query, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.workspace import get_workspace, get_workspace_file
|
||||
from backend.data.workspace import (
|
||||
WorkspaceFile,
|
||||
count_workspace_files,
|
||||
get_or_create_workspace,
|
||||
get_workspace,
|
||||
get_workspace_file,
|
||||
get_workspace_total_size,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
@@ -44,11 +58,11 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(content: bytes, file) -> Response:
|
||||
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mimeType,
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Length": str(len(content)),
|
||||
@@ -56,7 +70,7 @@ def _create_streaming_response(content: bytes, file) -> Response:
|
||||
)
|
||||
|
||||
|
||||
async def _create_file_download_response(file) -> Response:
|
||||
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
@@ -66,38 +80,53 @@ async def _create_file_download_response(file) -> Response:
|
||||
storage = await get_workspace_storage()
|
||||
|
||||
# For local storage, stream the file directly
|
||||
if file.storagePath.startswith("local://"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
if file.storage_path.startswith("local://"):
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||
url = await storage.get_download_url(file.storage_path, expires_in=300)
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
logger.error(
|
||||
f"Failed to get signed URL for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {e}",
|
||||
f"(storagePath={file.storage_path}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
content = await storage.retrieve(file.storage_path)
|
||||
return _create_streaming_response(content, file)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||
f"(storagePath={file.storage_path}): {fallback_error}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class StorageUsageResponse(BaseModel):
|
||||
used_bytes: int
|
||||
limit_bytes: int
|
||||
used_percent: float
|
||||
file_count: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
@@ -120,3 +149,120 @@ async def download_file(
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/files/upload",
|
||||
summary="Upload file to workspace",
|
||||
)
|
||||
async def upload_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file: UploadFile,
|
||||
session_id: str | None = Query(default=None),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to the user's workspace.
|
||||
|
||||
Files are stored in session-scoped paths when session_id is provided,
|
||||
so the agent's session-scoped tools can discover them automatically.
|
||||
"""
|
||||
config = Config()
|
||||
|
||||
# Sanitize filename — strip any directory components
|
||||
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||
|
||||
# Read file content with early abort on size limit
|
||||
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||
total_size += len(chunk)
|
||||
if total_size > max_file_bytes:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
|
||||
# Get or create workspace
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
current_usage = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded",
|
||||
"used_bytes": current_usage,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
"used_percent": round(used_percent, 1),
|
||||
},
|
||||
)
|
||||
|
||||
# Warn at 80% usage
|
||||
if (
|
||||
storage_limit_bytes
|
||||
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||
):
|
||||
logger.warning(
|
||||
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||
)
|
||||
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
# Write file via WorkspaceManager
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
workspace_file = await manager.write_file(content, filename)
|
||||
|
||||
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||
# If a concurrent upload pushed us over the limit, undo this write.
|
||||
new_total = await get_workspace_total_size(workspace.id)
|
||||
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=413,
|
||||
detail={
|
||||
"message": "Storage limit exceeded (concurrent upload)",
|
||||
"used_bytes": new_total,
|
||||
"limit_bytes": storage_limit_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_id=workspace_file.id,
|
||||
name=workspace_file.name,
|
||||
path=workspace_file.path,
|
||||
mime_type=workspace_file.mime_type,
|
||||
size_bytes=workspace_file.size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/storage/usage",
|
||||
summary="Get workspace storage usage",
|
||||
)
|
||||
async def get_storage_usage(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
) -> StorageUsageResponse:
|
||||
"""
|
||||
Get storage usage information for the user's workspace.
|
||||
"""
|
||||
config = Config()
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
used_bytes = await get_workspace_total_size(workspace.id)
|
||||
file_count = await count_workspace_files(workspace.id)
|
||||
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||
|
||||
return StorageUsageResponse(
|
||||
used_bytes=used_bytes,
|
||||
limit_bytes=limit_bytes,
|
||||
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
"""Tests for workspace file upload and download routes."""
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.api.features.workspace import routes as workspace_routes
|
||||
from backend.data.workspace import WorkspaceFile
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(workspace_routes.router)
|
||||
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def _value_error_handler(
|
||||
request: fastapi.Request, exc: ValueError
|
||||
) -> fastapi.responses.JSONResponse:
|
||||
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||
|
||||
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
MOCK_FILE = WorkspaceFile(
|
||||
id="file-aaa-bbb",
|
||||
workspace_id="ws-1",
|
||||
created_at=_NOW,
|
||||
updated_at=_NOW,
|
||||
name="hello.txt",
|
||||
path="/session/hello.txt",
|
||||
mime_type="text/plain",
|
||||
size_bytes=13,
|
||||
storage_path="local://hello.txt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _upload(
|
||||
filename: str = "hello.txt",
|
||||
content: bytes = b"Hello, world!",
|
||||
content_type: str = "text/plain",
|
||||
):
|
||||
"""Helper to POST a file upload."""
|
||||
return client.post(
|
||||
"/files/upload?session_id=sess-1",
|
||||
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||
)
|
||||
|
||||
|
||||
# ---- Happy path ----
|
||||
|
||||
|
||||
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["file_id"] == "file-aaa-bbb"
|
||||
assert data["name"] == "hello.txt"
|
||||
assert data["size_bytes"] == 13
|
||||
|
||||
|
||||
# ---- Per-file size limit ----
|
||||
|
||||
|
||||
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||
cfg.return_value.max_workspace_storage_mb = 500
|
||||
|
||||
response = _upload(content=b"x" * 1024)
|
||||
assert response.status_code == 413
|
||||
|
||||
|
||||
# ---- Storage quota exceeded ----
|
||||
|
||||
|
||||
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Current usage already at limit
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=500 * 1024 * 1024,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
assert "Storage limit exceeded" in response.text
|
||||
|
||||
|
||||
# ---- Post-write quota race (B2) ----
|
||||
|
||||
|
||||
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||
"""If a concurrent upload tips the total over the limit after write,
|
||||
the file should be soft-deleted and 413 returned."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
# Pre-write check passes (under limit), but post-write check fails
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
mock_delete = mocker.patch(
|
||||
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = _upload()
|
||||
assert response.status_code == 413
|
||||
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||
|
||||
|
||||
# ---- Any extension accepted (no allowlist) ----
|
||||
|
||||
|
||||
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---- Virus scan rejection ----
|
||||
|
||||
|
||||
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||
assert response.status_code == 400
|
||||
assert "Virus detected" in response.text
|
||||
mock_manager.write_file.assert_not_called()
|
||||
|
||||
|
||||
# ---- No file extension ----
|
||||
|
||||
|
||||
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||
"""Files without an extension should be accepted and stored as-is."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
response = _upload(
|
||||
filename="Makefile",
|
||||
content=b"all:\n\techo hello",
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_manager.write_file.assert_called_once()
|
||||
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||
|
||||
|
||||
# ---- Filename sanitization (SF5) ----
|
||||
|
||||
|
||||
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||
"""Path-traversal filenames should be reduced to their basename."""
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.scan_content_safe",
|
||||
return_value=None,
|
||||
)
|
||||
mock_manager = mocker.MagicMock()
|
||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
# Filename with traversal
|
||||
_upload(filename="../../etc/passwd.txt")
|
||||
|
||||
# write_file should have been called with just the basename
|
||||
mock_manager.write_file.assert_called_once()
|
||||
call_args = mock_manager.write_file.call_args
|
||||
assert call_args[0][1] == "passwd.txt"
|
||||
|
||||
|
||||
# ---- Download ----
|
||||
|
||||
|
||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace",
|
||||
return_value=MOCK_WORKSPACE,
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.api.features.workspace.routes.get_workspace_file",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
response = client.get("/files/some-file-id/download")
|
||||
assert response.status_code == 404
|
||||
@@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes
|
||||
import backend.api.features.library.db
|
||||
import backend.api.features.library.model
|
||||
import backend.api.features.library.routes
|
||||
import backend.api.features.mcp.routes as mcp_routes
|
||||
import backend.api.features.oauth
|
||||
import backend.api.features.otto.routes
|
||||
import backend.api.features.postmark.postmark
|
||||
@@ -40,9 +41,9 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
from backend.api.features.library.exceptions import (
|
||||
FolderAlreadyExistsError,
|
||||
FolderValidationError,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
@@ -122,21 +123,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
@@ -276,6 +265,10 @@ async def validation_error_handler(
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(
|
||||
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||
)
|
||||
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
@@ -343,6 +336,11 @@ app.include_router(
|
||||
tags=["workspace"],
|
||||
prefix="/api/workspace",
|
||||
)
|
||||
app.include_router(
|
||||
mcp_routes.router,
|
||||
tags=["v2", "mcp"],
|
||||
prefix="/api/mcp",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.oauth.router,
|
||||
tags=["oauth"],
|
||||
|
||||
@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
|
||||
# Run the last process in the foreground.
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in processes:
|
||||
for process in reversed(processes):
|
||||
try:
|
||||
process.stop()
|
||||
except Exception as e:
|
||||
@@ -38,7 +38,9 @@ def main(**kwargs):
|
||||
|
||||
from backend.api.rest_api import AgentServer
|
||||
from backend.api.ws_api import WebsocketServer
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.copilot.executor.manager import CoPilotExecutor
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
|
||||
run_processes(
|
||||
@@ -48,6 +50,7 @@ def main(**kwargs):
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
CoPilotExecutor(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ class BlockType(Enum):
|
||||
AI = "AI"
|
||||
AYRSHARE = "Ayrshare"
|
||||
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||
MCP_TOOL = "MCP Tool"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
|
||||
@@ -126,6 +126,7 @@ class PrintToConsoleBlock(Block):
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import json
|
||||
import shlex
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
@@ -20,6 +20,13 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.sandbox_files import (
|
||||
SandboxFileOutput,
|
||||
extract_and_store_sandbox_files,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.utils import ExecutionContext
|
||||
|
||||
|
||||
class ClaudeCodeExecutionError(Exception):
|
||||
@@ -174,22 +181,15 @@ class ClaudeCodeBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class FileOutput(BaseModel):
|
||||
"""A file extracted from the sandbox."""
|
||||
|
||||
path: str
|
||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||
name: str
|
||||
content: str
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
response: str = SchemaField(
|
||||
description="The output/response from Claude Code execution"
|
||||
)
|
||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||
files: list[SandboxFileOutput] = SchemaField(
|
||||
description=(
|
||||
"List of text files created/modified by Claude Code during this execution. "
|
||||
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||
"Each file has 'path', 'relative_path', 'name', 'content', and 'workspace_ref' fields. "
|
||||
"workspace_ref contains a workspace:// URI if the file was stored to workspace."
|
||||
)
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
@@ -252,6 +252,7 @@ class ClaudeCodeBlock(Block):
|
||||
"relative_path": "index.html",
|
||||
"name": "index.html",
|
||||
"content": "<html>Hello World</html>",
|
||||
"workspace_ref": None,
|
||||
}
|
||||
],
|
||||
),
|
||||
@@ -267,11 +268,12 @@ class ClaudeCodeBlock(Block):
|
||||
"execute_claude_code": lambda *args, **kwargs: (
|
||||
"Created index.html with hello world content", # response
|
||||
[
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
SandboxFileOutput(
|
||||
path="/home/user/index.html",
|
||||
relative_path="index.html",
|
||||
name="index.html",
|
||||
content="<html>Hello World</html>",
|
||||
workspace_ref=None,
|
||||
)
|
||||
], # files
|
||||
"User: Create a hello world HTML file\n"
|
||||
@@ -294,7 +296,8 @@ class ClaudeCodeBlock(Block):
|
||||
existing_sandbox_id: str,
|
||||
conversation_history: str,
|
||||
dispose_sandbox: bool,
|
||||
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||
execution_context: "ExecutionContext",
|
||||
) -> tuple[str, list[SandboxFileOutput], str, str, str]:
|
||||
"""
|
||||
Execute Claude Code in an E2B sandbox.
|
||||
|
||||
@@ -449,14 +452,18 @@ class ClaudeCodeBlock(Block):
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
# Extract files created/modified during this run
|
||||
files = await self._extract_files(
|
||||
sandbox, working_directory, start_timestamp
|
||||
# Extract files created/modified during this run and store to workspace
|
||||
sandbox_files = await extract_and_store_sandbox_files(
|
||||
sandbox=sandbox,
|
||||
working_directory=working_directory,
|
||||
execution_context=execution_context,
|
||||
since_timestamp=start_timestamp,
|
||||
text_only=True,
|
||||
)
|
||||
|
||||
return (
|
||||
response,
|
||||
files,
|
||||
sandbox_files, # Already SandboxFileOutput objects
|
||||
new_conversation_history,
|
||||
current_session_id,
|
||||
sandbox_id,
|
||||
@@ -471,140 +478,6 @@ class ClaudeCodeBlock(Block):
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
async def _extract_files(
|
||||
self,
|
||||
sandbox: BaseAsyncSandbox,
|
||||
working_directory: str,
|
||||
since_timestamp: str | None = None,
|
||||
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||
"""
|
||||
Extract text files created/modified during this Claude Code execution.
|
||||
|
||||
Args:
|
||||
sandbox: The E2B sandbox instance
|
||||
working_directory: Directory to search for files
|
||||
since_timestamp: ISO timestamp - only return files modified after this time
|
||||
|
||||
Returns:
|
||||
List of FileOutput objects with path, relative_path, name, and content
|
||||
"""
|
||||
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||
|
||||
# Text file extensions we can safely read as text
|
||||
text_extensions = {
|
||||
".txt",
|
||||
".md",
|
||||
".html",
|
||||
".htm",
|
||||
".css",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".py",
|
||||
".rb",
|
||||
".php",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".sql",
|
||||
".graphql",
|
||||
".env",
|
||||
".gitignore",
|
||||
".dockerfile",
|
||||
"Dockerfile",
|
||||
".vue",
|
||||
".svelte",
|
||||
".astro",
|
||||
".mdx",
|
||||
".rst",
|
||||
".tex",
|
||||
".csv",
|
||||
".log",
|
||||
}
|
||||
|
||||
try:
|
||||
# List files recursively using find command
|
||||
# Exclude node_modules and .git directories, but allow hidden files
|
||||
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||
# Filter by timestamp to only get files created/modified during this run
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
timestamp_filter = ""
|
||||
if since_timestamp:
|
||||
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||
find_result = await sandbox.commands.run(
|
||||
f"find {safe_working_dir} -type f "
|
||||
f"{timestamp_filter}"
|
||||
f"-not -path '*/node_modules/*' "
|
||||
f"-not -path '*/.git/*' "
|
||||
f"2>/dev/null"
|
||||
)
|
||||
|
||||
if find_result.stdout:
|
||||
for file_path in find_result.stdout.strip().split("\n"):
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Check if it's a text file we can read
|
||||
is_text = any(
|
||||
file_path.endswith(ext) for ext in text_extensions
|
||||
) or file_path.endswith("Dockerfile")
|
||||
|
||||
if is_text:
|
||||
try:
|
||||
content = await sandbox.files.read(file_path)
|
||||
# Handle bytes or string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
|
||||
# Extract filename from path
|
||||
file_name = file_path.split("/")[-1]
|
||||
|
||||
# Calculate relative path by stripping working directory
|
||||
relative_path = file_path
|
||||
if file_path.startswith(working_directory):
|
||||
relative_path = file_path[len(working_directory) :]
|
||||
# Remove leading slash if present
|
||||
if relative_path.startswith("/"):
|
||||
relative_path = relative_path[1:]
|
||||
|
||||
files.append(
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path=file_path,
|
||||
relative_path=relative_path,
|
||||
name=file_name,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Skip files that can't be read
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# If file extraction fails, return empty results
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
def _escape_prompt(self, prompt: str) -> str:
|
||||
"""Escape the prompt for safe shell execution."""
|
||||
# Use single quotes and escape any single quotes in the prompt
|
||||
@@ -617,6 +490,7 @@ class ClaudeCodeBlock(Block):
|
||||
*,
|
||||
e2b_credentials: APIKeyCredentials,
|
||||
anthropic_credentials: APIKeyCredentials,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
@@ -637,6 +511,7 @@ class ClaudeCodeBlock(Block):
|
||||
existing_sandbox_id=input_data.sandbox_id,
|
||||
conversation_history=input_data.conversation_history,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||
@@ -20,6 +20,13 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.sandbox_files import (
|
||||
SandboxFileOutput,
|
||||
extract_and_store_sandbox_files,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.utils import ExecutionContext
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -85,6 +92,9 @@ class CodeExecutionResult(MainCodeExecutionResult):
|
||||
class BaseE2BExecutorMixin:
|
||||
"""Shared implementation methods for E2B executor blocks."""
|
||||
|
||||
# Default working directory in E2B sandboxes
|
||||
WORKING_DIR = "/home/user"
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
api_key: str,
|
||||
@@ -95,14 +105,21 @@ class BaseE2BExecutorMixin:
|
||||
timeout: Optional[int] = None,
|
||||
sandbox_id: Optional[str] = None,
|
||||
dispose_sandbox: bool = False,
|
||||
execution_context: Optional["ExecutionContext"] = None,
|
||||
extract_files: bool = False,
|
||||
):
|
||||
"""
|
||||
Unified code execution method that handles all three use cases:
|
||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||
|
||||
Args:
|
||||
extract_files: If True and execution_context provided, extract files
|
||||
created/modified during execution and store to workspace.
|
||||
""" # noqa
|
||||
sandbox = None
|
||||
files: list[SandboxFileOutput] = []
|
||||
try:
|
||||
if sandbox_id:
|
||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||
@@ -118,6 +135,12 @@ class BaseE2BExecutorMixin:
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Capture timestamp before execution to scope file extraction
|
||||
start_timestamp = None
|
||||
if extract_files:
|
||||
ts_result = await sandbox.commands.run("date -u +%Y-%m-%dT%H:%M:%S")
|
||||
start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
@@ -133,7 +156,24 @@ class BaseE2BExecutorMixin:
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||
# Extract files created/modified during this execution
|
||||
if extract_files and execution_context:
|
||||
files = await extract_and_store_sandbox_files(
|
||||
sandbox=sandbox,
|
||||
working_directory=self.WORKING_DIR,
|
||||
execution_context=execution_context,
|
||||
since_timestamp=start_timestamp,
|
||||
text_only=False, # Include binary files too
|
||||
)
|
||||
|
||||
return (
|
||||
results,
|
||||
text_output,
|
||||
stdout_logs,
|
||||
stderr_logs,
|
||||
sandbox.sandbox_id,
|
||||
files,
|
||||
)
|
||||
finally:
|
||||
# Dispose of sandbox if requested to reduce usage costs
|
||||
if dispose_sandbox and sandbox:
|
||||
@@ -238,6 +278,12 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
files: list[SandboxFileOutput] = SchemaField(
|
||||
description=(
|
||||
"Files created or modified during execution. "
|
||||
"Each file has path, name, content, and workspace_ref (if stored)."
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -259,23 +305,30 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
("files", []),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox, execution_context, extract_files: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
[], # files
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
results, text_output, stdout, stderr, _, files = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.code,
|
||||
language=input_data.language,
|
||||
@@ -283,6 +336,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
execution_context=execution_context,
|
||||
extract_files=True,
|
||||
)
|
||||
|
||||
# Determine result object shape & filter out empty formats
|
||||
@@ -296,6 +351,8 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
# Always yield files (empty list if none)
|
||||
yield "files", [f.model_dump() for f in files]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -393,6 +450,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
[], # files
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -401,7 +459,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
||||
_, text_output, stdout, stderr, sandbox_id, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.setup_code,
|
||||
language=input_data.language,
|
||||
@@ -500,6 +558,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
sandbox_id, # sandbox_id
|
||||
[], # files
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -508,7 +567,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
results, text_output, stdout, stderr, _, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.step_code,
|
||||
language=input_data.language,
|
||||
|
||||
@@ -682,17 +682,219 @@ class ListIsEmptyBlock(Block):
|
||||
yield "is_empty", len(input_data.list) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# List Concatenation Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _validate_list_input(item: Any, index: int) -> str | None:
|
||||
"""Validate that an item is a list. Returns error message or None."""
|
||||
if item is None:
|
||||
return None # None is acceptable, will be skipped
|
||||
if not isinstance(item, list):
|
||||
return (
|
||||
f"Invalid input at index {index}: expected a list, "
|
||||
f"got {type(item).__name__}. "
|
||||
f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _validate_all_lists(lists: List[Any]) -> str | None:
|
||||
"""Validate that all items in a sequence are lists. Returns first error or None."""
|
||||
for idx, item in enumerate(lists):
|
||||
error = _validate_list_input(item, idx)
|
||||
if error is not None and item is not None:
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
def _concatenate_lists_simple(lists: List[List[Any]]) -> List[Any]:
|
||||
"""Concatenate a sequence of lists into a single list, skipping None values."""
|
||||
result: List[Any] = []
|
||||
for lst in lists:
|
||||
if lst is None:
|
||||
continue
|
||||
result.extend(lst)
|
||||
return result
|
||||
|
||||
|
||||
def _flatten_nested_list(nested: List[Any], max_depth: int = -1) -> List[Any]:
|
||||
"""
|
||||
Recursively flatten a nested list structure.
|
||||
|
||||
Args:
|
||||
nested: The list to flatten.
|
||||
max_depth: Maximum recursion depth. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
A flat list with all nested elements extracted.
|
||||
"""
|
||||
result: List[Any] = []
|
||||
_flatten_recursive(nested, result, current_depth=0, max_depth=max_depth)
|
||||
return result
|
||||
|
||||
|
||||
_MAX_FLATTEN_DEPTH = 1000
|
||||
|
||||
|
||||
def _flatten_recursive(
|
||||
items: List[Any],
|
||||
result: List[Any],
|
||||
current_depth: int,
|
||||
max_depth: int,
|
||||
) -> None:
|
||||
"""Internal recursive helper for flattening nested lists."""
|
||||
if current_depth > _MAX_FLATTEN_DEPTH:
|
||||
raise RecursionError(
|
||||
f"Flattening exceeded maximum depth of {_MAX_FLATTEN_DEPTH} levels. "
|
||||
"Input may be too deeply nested."
|
||||
)
|
||||
for item in items:
|
||||
if isinstance(item, list) and (max_depth == -1 or current_depth < max_depth):
|
||||
_flatten_recursive(item, result, current_depth + 1, max_depth)
|
||||
else:
|
||||
result.append(item)
|
||||
|
||||
|
||||
def _deduplicate_list(items: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Remove duplicate elements from a list, preserving order of first occurrences.
|
||||
|
||||
Args:
|
||||
items: The list to deduplicate.
|
||||
|
||||
Returns:
|
||||
A list with duplicates removed, maintaining original order.
|
||||
"""
|
||||
seen: set = set()
|
||||
result: List[Any] = []
|
||||
for item in items:
|
||||
item_id = _make_hashable(item)
|
||||
if item_id not in seen:
|
||||
seen.add(item_id)
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
|
||||
def _make_hashable(item: Any):
|
||||
"""
|
||||
Create a hashable representation of any item for deduplication.
|
||||
Converts unhashable types (dicts, lists) into deterministic tuple structures.
|
||||
"""
|
||||
if isinstance(item, dict):
|
||||
return tuple(
|
||||
sorted(
|
||||
((_make_hashable(k), _make_hashable(v)) for k, v in item.items()),
|
||||
key=lambda x: (str(type(x[0])), str(x[0])),
|
||||
)
|
||||
)
|
||||
if isinstance(item, (list, tuple)):
|
||||
return tuple(_make_hashable(i) for i in item)
|
||||
if isinstance(item, set):
|
||||
return frozenset(_make_hashable(i) for i in item)
|
||||
return item
|
||||
|
||||
|
||||
def _filter_none_values(items: List[Any]) -> List[Any]:
|
||||
"""Remove None values from a list."""
|
||||
return [item for item in items if item is not None]
|
||||
|
||||
|
||||
def _compute_nesting_depth(
|
||||
items: Any, current: int = 0, max_depth: int = _MAX_FLATTEN_DEPTH
|
||||
) -> int:
|
||||
"""
|
||||
Compute the maximum nesting depth of a list structure using iteration to avoid RecursionError.
|
||||
|
||||
Uses a stack-based approach to handle deeply nested structures without hitting Python's
|
||||
recursion limit (~1000 levels).
|
||||
"""
|
||||
if not isinstance(items, list):
|
||||
return current
|
||||
|
||||
# Stack contains tuples of (item, depth)
|
||||
stack = [(items, current)]
|
||||
max_observed_depth = current
|
||||
|
||||
while stack:
|
||||
item, depth = stack.pop()
|
||||
|
||||
if depth > max_depth:
|
||||
return depth
|
||||
|
||||
if not isinstance(item, list):
|
||||
max_observed_depth = max(max_observed_depth, depth)
|
||||
continue
|
||||
|
||||
if len(item) == 0:
|
||||
max_observed_depth = max(max_observed_depth, depth + 1)
|
||||
continue
|
||||
|
||||
# Add all children to stack with incremented depth
|
||||
for child in item:
|
||||
stack.append((child, depth + 1))
|
||||
|
||||
return max_observed_depth
|
||||
|
||||
|
||||
def _interleave_lists(lists: List[List[Any]]) -> List[Any]:
|
||||
"""
|
||||
Interleave elements from multiple lists in round-robin fashion.
|
||||
Example: [[1,2,3], [a,b], [x,y,z]] -> [1, a, x, 2, b, y, 3, z]
|
||||
"""
|
||||
if not lists:
|
||||
return []
|
||||
filtered = [lst for lst in lists if lst is not None]
|
||||
if not filtered:
|
||||
return []
|
||||
result: List[Any] = []
|
||||
max_len = max(len(lst) for lst in filtered)
|
||||
for i in range(max_len):
|
||||
for lst in filtered:
|
||||
if i < len(lst):
|
||||
result.append(lst[i])
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# List Concatenation Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ConcatenateListsBlock(Block):
|
||||
"""
|
||||
Concatenates two or more lists into a single list.
|
||||
|
||||
This block accepts a list of lists and combines all their elements
|
||||
in order into one flat output list. It supports options for
|
||||
deduplication and None-filtering to provide flexible list merging
|
||||
capabilities for workflow pipelines.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
lists: List[List[Any]] = SchemaField(
|
||||
description="A list of lists to concatenate together. All lists will be combined in order into a single list.",
|
||||
placeholder="e.g., [[1, 2], [3, 4], [5, 6]]",
|
||||
)
|
||||
deduplicate: bool = SchemaField(
|
||||
description="If True, remove duplicate elements from the concatenated result while preserving order.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
remove_none: bool = SchemaField(
|
||||
description="If True, remove None values from the concatenated result.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
concatenated_list: List[Any] = SchemaField(
|
||||
description="The concatenated list containing all elements from all input lists in order."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The total number of elements in the concatenated list."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if concatenation failed due to invalid input types."
|
||||
)
|
||||
@@ -700,7 +902,7 @@ class ConcatenateListsBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3cf9298b-5817-4141-9d80-7c2cc5199c8e",
|
||||
description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order.",
|
||||
description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order. Supports optional deduplication and None removal.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ConcatenateListsBlock.Input,
|
||||
output_schema=ConcatenateListsBlock.Output,
|
||||
@@ -709,29 +911,497 @@ class ConcatenateListsBlock(Block):
|
||||
{"lists": [["a", "b"], ["c"], ["d", "e", "f"]]},
|
||||
{"lists": [[1, 2], []]},
|
||||
{"lists": []},
|
||||
{"lists": [[1, 2, 2, 3], [3, 4]], "deduplicate": True},
|
||||
{"lists": [[1, None, 2], [None, 3]], "remove_none": True},
|
||||
],
|
||||
test_output=[
|
||||
("concatenated_list", [1, 2, 3, 4, 5, 6]),
|
||||
("length", 6),
|
||||
("concatenated_list", ["a", "b", "c", "d", "e", "f"]),
|
||||
("length", 6),
|
||||
("concatenated_list", [1, 2]),
|
||||
("length", 2),
|
||||
("concatenated_list", []),
|
||||
("length", 0),
|
||||
("concatenated_list", [1, 2, 3, 4]),
|
||||
("length", 4),
|
||||
("concatenated_list", [1, 2, 3]),
|
||||
("length", 3),
|
||||
],
|
||||
)
|
||||
|
||||
def _validate_inputs(self, lists: List[Any]) -> str | None:
|
||||
return _validate_all_lists(lists)
|
||||
|
||||
def _perform_concatenation(self, lists: List[List[Any]]) -> List[Any]:
|
||||
return _concatenate_lists_simple(lists)
|
||||
|
||||
def _apply_deduplication(self, items: List[Any]) -> List[Any]:
|
||||
return _deduplicate_list(items)
|
||||
|
||||
def _apply_none_removal(self, items: List[Any]) -> List[Any]:
|
||||
return _filter_none_values(items)
|
||||
|
||||
def _post_process(
|
||||
self, items: List[Any], deduplicate: bool, remove_none: bool
|
||||
) -> List[Any]:
|
||||
"""Apply all post-processing steps to the concatenated result."""
|
||||
result = items
|
||||
if remove_none:
|
||||
result = self._apply_none_removal(result)
|
||||
if deduplicate:
|
||||
result = self._apply_deduplication(result)
|
||||
return result
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
concatenated = []
|
||||
for idx, lst in enumerate(input_data.lists):
|
||||
if lst is None:
|
||||
# Skip None values to avoid errors
|
||||
continue
|
||||
if not isinstance(lst, list):
|
||||
# Type validation: each item must be a list
|
||||
# Strings are iterable and would cause extend() to iterate character-by-character
|
||||
# Non-iterable types would raise TypeError
|
||||
yield "error", (
|
||||
f"Invalid input at index {idx}: expected a list, got {type(lst).__name__}. "
|
||||
f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])."
|
||||
)
|
||||
return
|
||||
concatenated.extend(lst)
|
||||
yield "concatenated_list", concatenated
|
||||
# Validate all inputs are lists
|
||||
validation_error = self._validate_inputs(input_data.lists)
|
||||
if validation_error is not None:
|
||||
yield "error", validation_error
|
||||
return
|
||||
|
||||
# Perform concatenation
|
||||
concatenated = self._perform_concatenation(input_data.lists)
|
||||
|
||||
# Apply post-processing
|
||||
result = self._post_process(
|
||||
concatenated, input_data.deduplicate, input_data.remove_none
|
||||
)
|
||||
|
||||
yield "concatenated_list", result
|
||||
yield "length", len(result)
|
||||
|
||||
|
||||
class FlattenListBlock(Block):
|
||||
"""
|
||||
Flattens a nested list structure into a single flat list.
|
||||
|
||||
This block takes a list that may contain nested lists at any depth
|
||||
and produces a single-level list with all leaf elements. Useful
|
||||
for normalizing data structures from multiple sources that may
|
||||
have varying levels of nesting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
nested_list: List[Any] = SchemaField(
|
||||
description="A potentially nested list to flatten into a single-level list.",
|
||||
placeholder="e.g., [[1, [2, 3]], [4, [5, [6]]]]",
|
||||
)
|
||||
max_depth: int = SchemaField(
|
||||
description="Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level.",
|
||||
default=-1,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
flattened_list: List[Any] = SchemaField(
|
||||
description="The flattened list with all nested elements extracted."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The number of elements in the flattened list."
|
||||
)
|
||||
original_depth: int = SchemaField(
|
||||
description="The maximum nesting depth of the original input list."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if flattening failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cc45bb0f-d035-4756-96a7-fe3e36254b4d",
|
||||
description="Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FlattenListBlock.Input,
|
||||
output_schema=FlattenListBlock.Output,
|
||||
test_input=[
|
||||
{"nested_list": [[1, 2], [3, [4, 5]]]},
|
||||
{"nested_list": [1, [2, [3, [4]]]]},
|
||||
{"nested_list": [1, [2, [3, [4]]], 5], "max_depth": 1},
|
||||
{"nested_list": []},
|
||||
{"nested_list": [1, 2, 3]},
|
||||
],
|
||||
test_output=[
|
||||
("flattened_list", [1, 2, 3, 4, 5]),
|
||||
("length", 5),
|
||||
("original_depth", 3),
|
||||
("flattened_list", [1, 2, 3, 4]),
|
||||
("length", 4),
|
||||
("original_depth", 4),
|
||||
("flattened_list", [1, 2, [3, [4]], 5]),
|
||||
("length", 4),
|
||||
("original_depth", 4),
|
||||
("flattened_list", []),
|
||||
("length", 0),
|
||||
("original_depth", 1),
|
||||
("flattened_list", [1, 2, 3]),
|
||||
("length", 3),
|
||||
("original_depth", 1),
|
||||
],
|
||||
)
|
||||
|
||||
def _compute_depth(self, items: List[Any]) -> int:
|
||||
"""Compute the nesting depth of the input list."""
|
||||
return _compute_nesting_depth(items)
|
||||
|
||||
def _flatten(self, items: List[Any], max_depth: int) -> List[Any]:
|
||||
"""Flatten the list to the specified depth."""
|
||||
return _flatten_nested_list(items, max_depth=max_depth)
|
||||
|
||||
def _validate_max_depth(self, max_depth: int) -> str | None:
|
||||
"""Validate the max_depth parameter."""
|
||||
if max_depth < -1:
|
||||
return f"max_depth must be -1 (unlimited) or a non-negative integer, got {max_depth}"
|
||||
return None
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Validate max_depth
|
||||
depth_error = self._validate_max_depth(input_data.max_depth)
|
||||
if depth_error is not None:
|
||||
yield "error", depth_error
|
||||
return
|
||||
|
||||
original_depth = self._compute_depth(input_data.nested_list)
|
||||
flattened = self._flatten(input_data.nested_list, input_data.max_depth)
|
||||
|
||||
yield "flattened_list", flattened
|
||||
yield "length", len(flattened)
|
||||
yield "original_depth", original_depth
|
||||
|
||||
|
||||
class InterleaveListsBlock(Block):
|
||||
"""
|
||||
Interleaves elements from multiple lists in round-robin fashion.
|
||||
|
||||
Given multiple input lists, this block takes one element from each
|
||||
list in turn, producing an output where elements alternate between
|
||||
sources. Lists of different lengths are handled gracefully - shorter
|
||||
lists simply stop contributing once exhausted.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
lists: List[List[Any]] = SchemaField(
|
||||
description="A list of lists to interleave. Elements will be taken in round-robin order.",
|
||||
placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
interleaved_list: List[Any] = SchemaField(
|
||||
description="The interleaved list with elements alternating from each input list."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The total number of elements in the interleaved list."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if interleaving failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9f616084-1d9f-4f8e-bc00-5b9d2a75cd75",
|
||||
description="Interleaves elements from multiple lists in round-robin fashion, alternating between sources.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=InterleaveListsBlock.Input,
|
||||
output_schema=InterleaveListsBlock.Output,
|
||||
test_input=[
|
||||
{"lists": [[1, 2, 3], ["a", "b", "c"]]},
|
||||
{"lists": [[1, 2, 3], ["a", "b"], ["x", "y", "z"]]},
|
||||
{"lists": [[1], [2], [3]]},
|
||||
{"lists": []},
|
||||
],
|
||||
test_output=[
|
||||
("interleaved_list", [1, "a", 2, "b", 3, "c"]),
|
||||
("length", 6),
|
||||
("interleaved_list", [1, "a", "x", 2, "b", "y", 3, "z"]),
|
||||
("length", 8),
|
||||
("interleaved_list", [1, 2, 3]),
|
||||
("length", 3),
|
||||
("interleaved_list", []),
|
||||
("length", 0),
|
||||
],
|
||||
)
|
||||
|
||||
def _validate_inputs(self, lists: List[Any]) -> str | None:
|
||||
return _validate_all_lists(lists)
|
||||
|
||||
def _interleave(self, lists: List[List[Any]]) -> List[Any]:
|
||||
return _interleave_lists(lists)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
validation_error = self._validate_inputs(input_data.lists)
|
||||
if validation_error is not None:
|
||||
yield "error", validation_error
|
||||
return
|
||||
|
||||
result = self._interleave(input_data.lists)
|
||||
yield "interleaved_list", result
|
||||
yield "length", len(result)
|
||||
|
||||
|
||||
class ZipListsBlock(Block):
|
||||
"""
|
||||
Zips multiple lists together into a list of grouped tuples/lists.
|
||||
|
||||
Takes two or more input lists and combines corresponding elements
|
||||
into sub-lists. For example, zipping [1,2,3] and ['a','b','c']
|
||||
produces [[1,'a'], [2,'b'], [3,'c']]. Supports both truncating
|
||||
to shortest list and padding to longest list with a fill value.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
lists: List[List[Any]] = SchemaField(
|
||||
description="A list of lists to zip together. Corresponding elements will be grouped.",
|
||||
placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]",
|
||||
)
|
||||
pad_to_longest: bool = SchemaField(
|
||||
description="If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
fill_value: Any = SchemaField(
|
||||
description="Value to use for padding when pad_to_longest is True.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
zipped_list: List[List[Any]] = SchemaField(
|
||||
description="The zipped list of grouped elements."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The number of groups in the zipped result."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if zipping failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d0e684f-5cb9-4c4b-b8d1-47a0860e0c07",
|
||||
description="Zips multiple lists together into a list of grouped elements. Supports padding to longest or truncating to shortest.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ZipListsBlock.Input,
|
||||
output_schema=ZipListsBlock.Output,
|
||||
test_input=[
|
||||
{"lists": [[1, 2, 3], ["a", "b", "c"]]},
|
||||
{"lists": [[1, 2, 3], ["a", "b"]]},
|
||||
{
|
||||
"lists": [[1, 2], ["a", "b", "c"]],
|
||||
"pad_to_longest": True,
|
||||
"fill_value": 0,
|
||||
},
|
||||
{"lists": []},
|
||||
],
|
||||
test_output=[
|
||||
("zipped_list", [[1, "a"], [2, "b"], [3, "c"]]),
|
||||
("length", 3),
|
||||
("zipped_list", [[1, "a"], [2, "b"]]),
|
||||
("length", 2),
|
||||
("zipped_list", [[1, "a"], [2, "b"], [0, "c"]]),
|
||||
("length", 3),
|
||||
("zipped_list", []),
|
||||
("length", 0),
|
||||
],
|
||||
)
|
||||
|
||||
def _validate_inputs(self, lists: List[Any]) -> str | None:
|
||||
return _validate_all_lists(lists)
|
||||
|
||||
def _zip_truncate(self, lists: List[List[Any]]) -> List[List[Any]]:
|
||||
"""Zip lists, truncating to shortest."""
|
||||
filtered = [lst for lst in lists if lst is not None]
|
||||
if not filtered:
|
||||
return []
|
||||
return [list(group) for group in zip(*filtered)]
|
||||
|
||||
def _zip_pad(self, lists: List[List[Any]], fill_value: Any) -> List[List[Any]]:
|
||||
"""Zip lists, padding shorter ones with fill_value."""
|
||||
if not lists:
|
||||
return []
|
||||
lists = [lst for lst in lists if lst is not None]
|
||||
if not lists:
|
||||
return []
|
||||
max_len = max(len(lst) for lst in lists)
|
||||
result: List[List[Any]] = []
|
||||
for i in range(max_len):
|
||||
group: List[Any] = []
|
||||
for lst in lists:
|
||||
if i < len(lst):
|
||||
group.append(lst[i])
|
||||
else:
|
||||
group.append(fill_value)
|
||||
result.append(group)
|
||||
return result
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
validation_error = self._validate_inputs(input_data.lists)
|
||||
if validation_error is not None:
|
||||
yield "error", validation_error
|
||||
return
|
||||
|
||||
if not input_data.lists:
|
||||
yield "zipped_list", []
|
||||
yield "length", 0
|
||||
return
|
||||
|
||||
if input_data.pad_to_longest:
|
||||
result = self._zip_pad(input_data.lists, input_data.fill_value)
|
||||
else:
|
||||
result = self._zip_truncate(input_data.lists)
|
||||
|
||||
yield "zipped_list", result
|
||||
yield "length", len(result)
|
||||
|
||||
|
||||
class ListDifferenceBlock(Block):
|
||||
"""
|
||||
Computes the difference between two lists (elements in the first
|
||||
list that are not in the second list).
|
||||
|
||||
This is useful for finding items that exist in one dataset but
|
||||
not in another, such as finding new items, missing items, or
|
||||
items that need to be processed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
list_a: List[Any] = SchemaField(
|
||||
description="The primary list to check elements from.",
|
||||
placeholder="e.g., [1, 2, 3, 4, 5]",
|
||||
)
|
||||
list_b: List[Any] = SchemaField(
|
||||
description="The list to subtract. Elements found here will be removed from list_a.",
|
||||
placeholder="e.g., [3, 4, 5, 6]",
|
||||
)
|
||||
symmetric: bool = SchemaField(
|
||||
description="If True, compute symmetric difference (elements in either list but not both).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
difference: List[Any] = SchemaField(
|
||||
description="Elements from list_a not found in list_b (or symmetric difference if enabled)."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The number of elements in the difference result."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="05309873-9d61-447e-96b5-b804e2511829",
|
||||
description="Computes the difference between two lists. Returns elements in the first list not found in the second, or symmetric difference.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ListDifferenceBlock.Input,
|
||||
output_schema=ListDifferenceBlock.Output,
|
||||
test_input=[
|
||||
{"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]},
|
||||
{
|
||||
"list_a": [1, 2, 3, 4, 5],
|
||||
"list_b": [3, 4, 5, 6, 7],
|
||||
"symmetric": True,
|
||||
},
|
||||
{"list_a": ["a", "b", "c"], "list_b": ["b"]},
|
||||
{"list_a": [], "list_b": [1, 2, 3]},
|
||||
],
|
||||
test_output=[
|
||||
("difference", [1, 2]),
|
||||
("length", 2),
|
||||
("difference", [1, 2, 6, 7]),
|
||||
("length", 4),
|
||||
("difference", ["a", "c"]),
|
||||
("length", 2),
|
||||
("difference", []),
|
||||
("length", 0),
|
||||
],
|
||||
)
|
||||
|
||||
def _compute_difference(self, list_a: List[Any], list_b: List[Any]) -> List[Any]:
|
||||
"""Compute elements in list_a not in list_b."""
|
||||
b_hashes = {_make_hashable(item) for item in list_b}
|
||||
return [item for item in list_a if _make_hashable(item) not in b_hashes]
|
||||
|
||||
def _compute_symmetric_difference(
|
||||
self, list_a: List[Any], list_b: List[Any]
|
||||
) -> List[Any]:
|
||||
"""Compute elements in either list but not both."""
|
||||
a_hashes = {_make_hashable(item) for item in list_a}
|
||||
b_hashes = {_make_hashable(item) for item in list_b}
|
||||
only_in_a = [item for item in list_a if _make_hashable(item) not in b_hashes]
|
||||
only_in_b = [item for item in list_b if _make_hashable(item) not in a_hashes]
|
||||
return only_in_a + only_in_b
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.symmetric:
|
||||
result = self._compute_symmetric_difference(
|
||||
input_data.list_a, input_data.list_b
|
||||
)
|
||||
else:
|
||||
result = self._compute_difference(input_data.list_a, input_data.list_b)
|
||||
|
||||
yield "difference", result
|
||||
yield "length", len(result)
|
||||
|
||||
|
||||
class ListIntersectionBlock(Block):
|
||||
"""
|
||||
Computes the intersection of two lists (elements present in both lists).
|
||||
|
||||
This is useful for finding common items between two datasets,
|
||||
such as shared tags, mutual connections, or overlapping categories.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
list_a: List[Any] = SchemaField(
|
||||
description="The first list to intersect.",
|
||||
placeholder="e.g., [1, 2, 3, 4, 5]",
|
||||
)
|
||||
list_b: List[Any] = SchemaField(
|
||||
description="The second list to intersect.",
|
||||
placeholder="e.g., [3, 4, 5, 6, 7]",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
intersection: List[Any] = SchemaField(
|
||||
description="Elements present in both list_a and list_b."
|
||||
)
|
||||
length: int = SchemaField(
|
||||
description="The number of elements in the intersection."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6eb08b6-dbe3-411b-b9b4-2508cb311a1f",
|
||||
description="Computes the intersection of two lists, returning only elements present in both.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ListIntersectionBlock.Input,
|
||||
output_schema=ListIntersectionBlock.Output,
|
||||
test_input=[
|
||||
{"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]},
|
||||
{"list_a": ["a", "b", "c"], "list_b": ["c", "d", "e"]},
|
||||
{"list_a": [1, 2], "list_b": [3, 4]},
|
||||
{"list_a": [], "list_b": [1, 2, 3]},
|
||||
],
|
||||
test_output=[
|
||||
("intersection", [3, 4, 5]),
|
||||
("length", 3),
|
||||
("intersection", ["c"]),
|
||||
("length", 1),
|
||||
("intersection", []),
|
||||
("length", 0),
|
||||
("intersection", []),
|
||||
("length", 0),
|
||||
],
|
||||
)
|
||||
|
||||
def _compute_intersection(self, list_a: List[Any], list_b: List[Any]) -> List[Any]:
|
||||
"""Compute elements present in both lists, preserving order from list_a."""
|
||||
b_hashes = {_make_hashable(item) for item in list_b}
|
||||
seen: set = set()
|
||||
result: List[Any] = []
|
||||
for item in list_a:
|
||||
h = _make_hashable(item)
|
||||
if h in b_hashes and h not in seen:
|
||||
result.append(item)
|
||||
seen.add(h)
|
||||
return result
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
result = self._compute_intersection(input_data.list_a, input_data.list_b)
|
||||
yield "intersection", result
|
||||
yield "length", len(result)
|
||||
|
||||
@@ -17,6 +17,7 @@ from backend.blocks.jina._auth import (
|
||||
from backend.blocks.search import GetRequest
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -110,7 +111,12 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if input_data.raw_content:
|
||||
url = input_data.url
|
||||
try:
|
||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
||||
url = parsed_url.geturl()
|
||||
except ValueError as e:
|
||||
yield "error", f"Invalid URL: {e}"
|
||||
return
|
||||
headers = {}
|
||||
else:
|
||||
url = f"https://r.jina.ai/{input_data.url}"
|
||||
@@ -119,5 +125,20 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
||||
}
|
||||
|
||||
content = await self.get_request(url, json=False, headers=headers)
|
||||
try:
|
||||
content = await self.get_request(url, json=False, headers=headers)
|
||||
except HTTPClientError as e:
|
||||
yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
|
||||
return
|
||||
except HTTPServerError as e:
|
||||
yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
|
||||
return
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch {input_data.url}: {e}"
|
||||
return
|
||||
|
||||
if not content:
|
||||
yield "error", f"No content returned for {input_data.url}"
|
||||
return
|
||||
|
||||
yield "content", content
|
||||
|
||||
300
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
300
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) Tool Block.
|
||||
|
||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.data.block import BlockInput, BlockOutput
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="test-mcp-cred",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("mock-mcp-token"),
|
||||
refresh_token=SecretStr("mock-refresh"),
|
||||
scopes=[],
|
||||
title="Mock MCP credential",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
||||
|
||||
|
||||
class MCPToolBlock(Block):
|
||||
"""
|
||||
A block that connects to an MCP server, lets the user pick a tool,
|
||||
and executes it with dynamic input/output schema.
|
||||
|
||||
The flow:
|
||||
1. User provides an MCP server URL (and optional credentials)
|
||||
2. Frontend calls the backend to get tool list from that URL
|
||||
3. User selects a tool from a dropdown (available_tools)
|
||||
4. The block's input schema updates to reflect the selected tool's parameters
|
||||
5. On execution, the block calls the MCP server to run the tool
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
server_url: str = SchemaField(
|
||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||
placeholder="https://mcp.example.com/mcp",
|
||||
)
|
||||
credentials: MCPCredentials = CredentialsField(
|
||||
discriminator="server_url",
|
||||
description="MCP server OAuth credentials",
|
||||
default={},
|
||||
)
|
||||
selected_tool: str = SchemaField(
|
||||
description="The MCP tool to execute",
|
||||
placeholder="Select a tool",
|
||||
default="",
|
||||
)
|
||||
tool_input_schema: dict[str, Any] = SchemaField(
|
||||
description="JSON Schema for the selected tool's input parameters. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
"The fields here are defined by the tool's input schema.",
|
||||
default={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
||||
return data.get("tool_input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
||||
return data.get("tool_arguments", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
"""Check which required tool arguments are missing."""
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return set(required_fields) - set(tool_arguments)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
"""Validate tool_arguments against the tool's input schema."""
|
||||
tool_schema = cls.get_input_schema(data)
|
||||
if not tool_schema:
|
||||
return None
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
||||
error: str = SchemaField(description="Error message if the tool call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||
description="Connect to any MCP server and execute its tools. "
|
||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MCPToolBlock.Input,
|
||||
output_schema=MCPToolBlock.Output,
|
||||
block_type=BlockType.MCP_TOOL,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"selected_tool": "get_weather",
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"weather": "sunny", "temperature": 20},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_call_mcp_tool": lambda *a, **kw: {
|
||||
"weather": "sunny",
|
||||
"temperature": 20,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def _call_mcp_tool(
|
||||
self,
|
||||
server_url: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
auth_token: str | None = None,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
||||
client = MCPClient(server_url, auth_token=auth_token)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
if result.is_error:
|
||||
error_text = ""
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
error_text += item.get("text", "")
|
||||
raise MCPClientError(
|
||||
f"MCP tool '{tool_name}' returned an error: "
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
async def _auto_lookup_credential(
|
||||
user_id: str, server_url: str
|
||||
) -> "OAuth2Credentials | None":
|
||||
"""Auto-lookup stored MCP credential for a server URL.
|
||||
|
||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||
set (e.g. nodes created before the credential field was wired up).
|
||||
"""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
mgr = IntegrationCredentialsManager()
|
||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||
user_id, ProviderName.MCP.value
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = cred
|
||||
if best:
|
||||
best = await mgr.refresh_if_needed(user_id, best)
|
||||
logger.info(
|
||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||
)
|
||||
return best
|
||||
except Exception:
|
||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
credentials: OAuth2Credentials | None = None,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.server_url:
|
||||
yield "error", "MCP server URL is required"
|
||||
return
|
||||
|
||||
if not input_data.selected_tool:
|
||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||
return
|
||||
|
||||
# Validate required tool arguments before calling the server.
|
||||
# The executor-level validation is bypassed for MCP blocks because
|
||||
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
||||
# from the validation context.
|
||||
required = set(input_data.tool_input_schema.get("required", []))
|
||||
if required:
|
||||
missing = required - set(input_data.tool_arguments.keys())
|
||||
if missing:
|
||||
yield "error", (
|
||||
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
||||
f"Please fill in all required fields marked with * in the block form."
|
||||
)
|
||||
return
|
||||
|
||||
# If no credentials were injected by the executor (e.g. legacy nodes
|
||||
# that don't have the credentials field set), try to auto-lookup
|
||||
# the stored MCP credential for this server URL.
|
||||
if credentials is None:
|
||||
credentials = await self._auto_lookup_credential(
|
||||
user_id, input_data.server_url
|
||||
)
|
||||
|
||||
auth_token = (
|
||||
credentials.access_token.get_secret_value() if credentials else None
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._call_mcp_tool(
|
||||
server_url=input_data.server_url,
|
||||
tool_name=input_data.selected_tool,
|
||||
arguments=input_data.tool_arguments,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
yield "result", result
|
||||
except MCPClientError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
logger.exception(f"MCP tool call failed: {e}")
|
||||
yield "error", f"MCP tool call failed: {str(e)}"
|
||||
323
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
323
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) HTTP client.
|
||||
|
||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
||||
|
||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
||||
|
||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPTool:
|
||||
"""Represents an MCP tool discovered from a server."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPCallResult:
|
||||
"""Result from calling an MCP tool."""
|
||||
|
||||
content: list[dict[str, Any]] = field(default_factory=list)
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
class MCPClientError(Exception):
|
||||
"""Raised when an MCP protocol error occurs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
Async HTTP client for the MCP Streamable HTTP transport.
|
||||
|
||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
||||
Supports optional Bearer token authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
):
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.auth_token = auth_token
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
if self._session_id:
|
||||
headers["Mcp-Session-Id"] = self._session_id
|
||||
return headers
|
||||
|
||||
def _build_jsonrpc_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
req: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"id": self._next_id(),
|
||||
}
|
||||
if params is not None:
|
||||
req["params"] = params
|
||||
return req
|
||||
|
||||
@staticmethod
|
||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
||||
|
||||
MCP servers may return responses as SSE with format:
|
||||
event: message
|
||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
||||
|
||||
We extract the last `data:` line that contains a JSON-RPC response
|
||||
(i.e. has an "id" field), which is the reply to our request.
|
||||
"""
|
||||
last_data: dict[str, Any] | None = None
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("data:"):
|
||||
payload = stripped[len("data:") :].strip()
|
||||
if not payload:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
||||
if isinstance(parsed, dict) and "id" in parsed:
|
||||
last_data = parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if last_data is None:
|
||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
||||
return last_data
|
||||
|
||||
async def _send_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
||||
|
||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
||||
as required by the MCP Streamable HTTP transport specification.
|
||||
"""
|
||||
payload = self._build_jsonrpc_request(method, params)
|
||||
headers = self._build_headers()
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=True,
|
||||
extra_headers=headers,
|
||||
)
|
||||
response = await requests.post(self.server_url, json=payload)
|
||||
|
||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
||||
session_id = response.headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
self._session_id = session_id
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "text/event-stream" in content_type:
|
||||
body = self._parse_sse_response(response.text())
|
||||
else:
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception as e:
|
||||
raise MCPClientError(
|
||||
f"MCP server returned non-JSON response: {e}"
|
||||
) from e
|
||||
|
||||
if not isinstance(body, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server returned unexpected JSON type: {type(body).__name__}"
|
||||
)
|
||||
|
||||
# Handle JSON-RPC error
|
||||
if "error" in body:
|
||||
error = body["error"]
|
||||
if isinstance(error, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server error [{error.get('code', '?')}]: "
|
||||
f"{error.get('message', 'Unknown error')}"
|
||||
)
|
||||
raise MCPClientError(f"MCP server error: {error}")
|
||||
|
||||
return body.get("result")
|
||||
|
||||
async def _send_notification(self, method: str) -> None:
|
||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||
headers = self._build_headers()
|
||||
notification = {"jsonrpc": "2.0", "method": method}
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
extra_headers=headers,
|
||||
)
|
||||
await requests.post(self.server_url, json=notification)
|
||||
|
||||
async def discover_auth(self) -> dict[str, Any] | None:
|
||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
||||
|
||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
||||
a dict with:
|
||||
- ``authorization_servers``: list of authorization server URLs
|
||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
||||
- ``scopes_supported``: optional list of supported scopes
|
||||
|
||||
The caller can then fetch the authorization server metadata to get
|
||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(self.server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
||||
path = parsed.path.rstrip("/")
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_servers" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def discover_auth_server_metadata(
|
||||
self, auth_server_url: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Given an authorization server URL, returns a dict with:
|
||||
- ``authorization_endpoint``
|
||||
- ``token_endpoint``
|
||||
- ``registration_endpoint`` (for dynamic client registration)
|
||||
- ``scopes_supported``
|
||||
- ``code_challenge_methods_supported``
|
||||
- etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(auth_server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""
|
||||
Send the MCP initialize request.
|
||||
|
||||
This is required by the MCP protocol before any other requests.
|
||||
Returns the server's capabilities.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
||||
},
|
||||
)
|
||||
# Send initialized notification (no response expected)
|
||||
await self._send_notification("notifications/initialized")
|
||||
|
||||
return result or {}
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
Discover available tools from the MCP server.
|
||||
|
||||
Returns a list of MCPTool objects with name, description, and input schema.
|
||||
"""
|
||||
result = await self._send_request("tools/list")
|
||||
if not result or "tools" not in result:
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_data in result["tools"]:
|
||||
tools.append(
|
||||
MCPTool(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description", ""),
|
||||
input_schema=tool_data.get("inputSchema", {}),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> MCPCallResult:
|
||||
"""
|
||||
Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
arguments: The arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
MCPCallResult with the tool's response content.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"tools/call",
|
||||
{"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
if not result:
|
||||
return MCPCallResult(is_error=True)
|
||||
|
||||
return MCPCallResult(
|
||||
content=result.get("content", []),
|
||||
is_error=result.get("isError", False),
|
||||
)
|
||||
204
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
204
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
||||
|
||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
||||
This handler accepts those endpoints at construction time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
||||
|
||||
Construction requires the authorization and token endpoint URLs,
|
||||
which are obtained via MCP OAuth metadata discovery
|
||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
||||
"""
|
||||
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
revoke_url: str | None = None,
|
||||
resource_url: str | None = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.revoke_url = revoke_url
|
||||
self.resource_url = resource_url
|
||||
|
||||
def get_login_url(
|
||||
self,
|
||||
scopes: list[str],
|
||||
state: str,
|
||||
code_challenge: Optional[str],
|
||||
) -> str:
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": state,
|
||||
}
|
||||
if scopes:
|
||||
params["scope"] = " ".join(scopes)
|
||||
# PKCE (S256) — included when the caller provides a code_challenge
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
# MCP spec requires resource indicator (RFC 8707)
|
||||
if self.resource_url:
|
||||
params["resource"] = self.resource_url
|
||||
|
||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self,
|
||||
code: str,
|
||||
scopes: list[str],
|
||||
code_verifier: Optional[str],
|
||||
) -> OAuth2Credentials:
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
if "access_token" not in tokens:
|
||||
raise RuntimeError("OAuth token response missing 'access_token' field")
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(tokens["refresh_token"])
|
||||
if tokens.get("refresh_token")
|
||||
else None
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=scopes,
|
||||
metadata={
|
||||
"mcp_token_url": self.token_url,
|
||||
"mcp_resource_url": self.resource_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
if "access_token" not in tokens:
|
||||
raise RuntimeError("OAuth refresh response missing 'access_token' field")
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(tokens["refresh_token"])
|
||||
if tokens.get("refresh_token")
|
||||
else credentials.refresh_token
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
||||
scopes=credentials.scopes,
|
||||
metadata=credentials.metadata,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not self.revoke_url:
|
||||
return False
|
||||
|
||||
try:
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
await Requests().post(
|
||||
self.revoke_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
||||
return False
|
||||
109
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
109
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
End-to-end tests against a real public MCP server.
|
||||
|
||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
||||
which is publicly accessible without authentication and returns SSE responses.
|
||||
|
||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
||||
independently of the rest of the test suite (they require network access).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
|
||||
# Public MCP server that requires no authentication
|
||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
||||
|
||||
# Skip all tests in this module unless RUN_E2E env var is set
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
|
||||
)
|
||||
|
||||
|
||||
class TestRealMCPServer:
|
||||
"""Tests against the live OpenAI docs MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self):
|
||||
"""Verify we can complete the MCP handshake with a real server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert "serverInfo" in result
|
||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
||||
assert "tools" in result.get("capabilities", {})
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self):
|
||||
"""Verify we can discover tools from a real MCP server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
# These tools are documented and should be stable
|
||||
assert "search_openai_docs" in tool_names
|
||||
assert "list_openai_docs" in tool_names
|
||||
assert "fetch_openai_doc" in tool_names
|
||||
|
||||
# Verify schema structure
|
||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
||||
assert "query" in search_tool.input_schema.get("properties", {})
|
||||
assert "query" in search_tool.input_schema.get("required", [])
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_list_api_endpoints(self):
|
||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("list_api_endpoints", {})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert "paths" in data or "urls" in data
|
||||
# The OpenAI API should have many endpoints
|
||||
total = data.get("total", len(data.get("paths", [])))
|
||||
assert total > 50
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_search(self):
|
||||
"""Search for docs and verify we get results."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(
|
||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
||||
)
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_sse_response_handling(self):
|
||||
"""Verify the client correctly handles SSE responses from a real server.
|
||||
|
||||
This is the key test — our local test server returns JSON,
|
||||
but real MCP servers typically return SSE. This proves the
|
||||
SSE parsing works end-to-end.
|
||||
"""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
# initialize() internally calls _send_request which must parse SSE
|
||||
result = await client.initialize()
|
||||
|
||||
# If we got here without error, SSE parsing works
|
||||
assert isinstance(result, dict)
|
||||
assert "protocolVersion" in result
|
||||
|
||||
# Also verify list_tools works (another SSE response)
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) > 0
|
||||
assert all(hasattr(t, "name") for t in tools)
|
||||
389
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
389
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
||||
|
||||
These tests spin up a local MCP test server and run the full client/block flow
|
||||
against it — no mocking, real HTTP requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
MOCK_USER_ID = "test-user-integration"
|
||||
|
||||
|
||||
class _MCPTestServer:
|
||||
"""
|
||||
Run an MCP test server in a background thread with its own event loop.
|
||||
This avoids event loop conflicts with pytest-asyncio.
|
||||
"""
|
||||
|
||||
def __init__(self, auth_token: str | None = None):
|
||||
self.auth_token = auth_token
|
||||
self.url: str = ""
|
||||
self._runner: web.AppRunner | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def _run(self):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_until_complete(self._start())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
async def _start(self):
|
||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
if not self._started.wait(timeout=5):
|
||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
if self._loop and self._runner:
|
||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
||||
timeout=5
|
||||
)
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server():
|
||||
"""Start a local MCP test server in a background thread."""
|
||||
server = _MCPTestServer()
|
||||
server.start()
|
||||
yield server.url
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server_with_auth():
|
||||
"""Start a local MCP test server with auth in a background thread."""
|
||||
server = _MCPTestServer(auth_token="test-secret-token")
|
||||
server.start()
|
||||
yield server.url, "test-secret-token"
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _allow_localhost():
|
||||
"""
|
||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
||||
|
||||
The Requests class blocks private IPs by default. We patch the Requests
|
||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
||||
test server is reachable.
|
||||
"""
|
||||
from backend.util.request import Requests
|
||||
|
||||
original_init = Requests.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
trusted = list(kwargs.get("trusted_origins") or [])
|
||||
trusted.append("http://127.0.0.1")
|
||||
kwargs["trusted_origins"] = trusted
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
with patch.object(Requests, "__init__", patched_init):
|
||||
yield
|
||||
|
||||
|
||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
||||
"""Create an MCPClient for integration tests."""
|
||||
return MCPClient(url, auth_token=auth_token)
|
||||
|
||||
|
||||
# ── MCPClient integration tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClientIntegration:
|
||||
"""Test MCPClient against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
||||
assert "tools" in result["capabilities"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
||||
|
||||
# Check get_weather schema
|
||||
weather = next(t for t in tools if t.name == "get_weather")
|
||||
assert weather.description == "Get current weather for a city"
|
||||
assert "city" in weather.input_schema["properties"]
|
||||
assert weather.input_schema["required"] == ["city"]
|
||||
|
||||
# Check add_numbers schema
|
||||
add = next(t for t in tools if t.name == "add_numbers")
|
||||
assert "a" in add.input_schema["properties"]
|
||||
assert "b" in add.input_schema["properties"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_get_weather(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["city"] == "London"
|
||||
assert data["temperature"] == 22
|
||||
assert data["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_add_numbers(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
||||
|
||||
assert not result.is_error
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["result"] == 10
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_echo(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
||||
|
||||
assert not result.is_error
|
||||
assert result.content[0]["text"] == "Hello MCP!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_unknown_tool(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("nonexistent_tool", {})
|
||||
|
||||
assert result.is_error
|
||||
assert "Unknown tool" in result.content[0]["text"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_success(self, mcp_server_with_auth):
|
||||
url, token = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token=token)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_failure(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token="wrong-token")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_auth_missing(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
|
||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPToolBlockIntegration:
|
||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_get_weather(self, mcp_server):
|
||||
"""Full flow: discover tools, select one, execute it."""
|
||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
# Step 2: User selects "get_weather" and we get its schema
|
||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
||||
|
||||
# Step 3: Execute the block — no credentials (public server)
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema=weather_tool.input_schema,
|
||||
tool_arguments={"city": "Paris"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
result = outputs[0][1]
|
||||
assert result["city"] == "Paris"
|
||||
assert result["temperature"] == 22
|
||||
assert result["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_add_numbers(self, mcp_server):
|
||||
"""Full flow for add_numbers tool."""
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="add_numbers",
|
||||
tool_input_schema=add_tool.input_schema,
|
||||
tool_arguments={"a": 42, "b": 58},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1]["result"] == 100
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
||||
"""Verify plain text (non-JSON) responses work."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
||||
"""Calling an unknown tool should yield an error output."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="nonexistent_tool",
|
||||
tool_arguments={},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "error"
|
||||
assert "returned an error" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
||||
"""Full flow with authentication via credentials kwarg."""
|
||||
url, token = mcp_server_with_auth
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=url,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Authenticated!"},
|
||||
)
|
||||
|
||||
# Pass credentials via the standard kwarg (as the executor would)
|
||||
test_creds = OAuth2Credentials(
|
||||
id="test-cred",
|
||||
provider="mcp",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr(""),
|
||||
scopes=[],
|
||||
title="Test MCP credential",
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
||||
):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Authenticated!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
||||
"""Block runs without auth when no credentials are provided."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "No auth needed"},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
||||
):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "No auth needed"
|
||||
619
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
619
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
Tests for MCP client and MCPToolBlock.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSSEParsing:
|
||||
"""Tests for SSE (text/event-stream) response parsing."""
|
||||
|
||||
def test_parse_sse_simple(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"tools": []}
|
||||
assert body["id"] == 1
|
||||
|
||||
def test_parse_sse_with_notifications(self):
|
||||
"""SSE streams can contain notifications (no id) before the response."""
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
||||
"\n"
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"ok": True}
|
||||
assert body["id"] == 2
|
||||
|
||||
def test_parse_sse_error_response(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == -32600
|
||||
|
||||
def test_parse_sse_no_data_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("event: message\n\n")
|
||||
|
||||
def test_parse_sse_empty_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("")
|
||||
|
||||
def test_parse_sse_ignores_non_data_lines(self):
|
||||
sse = (
|
||||
": comment line\n"
|
||||
"event: message\n"
|
||||
"id: 123\n"
|
||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "ok"
|
||||
|
||||
def test_parse_sse_uses_last_response(self):
|
||||
"""If multiple responses exist, use the last one."""
|
||||
sse = (
|
||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
||||
"\n"
|
||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "second"
|
||||
|
||||
|
||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
"""Tests for the MCP HTTP client."""
|
||||
|
||||
def test_build_headers_without_auth(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
headers = client._build_headers()
|
||||
assert "Authorization" not in headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
def test_build_headers_with_auth(self):
|
||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
||||
headers = client._build_headers()
|
||||
assert headers["Authorization"] == "Bearer my-token"
|
||||
|
||||
def test_build_jsonrpc_request(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request("tools/list")
|
||||
assert req["jsonrpc"] == "2.0"
|
||||
assert req["method"] == "tools/list"
|
||||
assert "id" in req
|
||||
assert "params" not in req
|
||||
|
||||
def test_build_jsonrpc_request_with_params(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request(
|
||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
||||
)
|
||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
||||
|
||||
def test_request_id_increments(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req1 = client._build_jsonrpc_request("tools/list")
|
||||
req2 = client._build_jsonrpc_request("tools/list")
|
||||
assert req2["id"] > req1["id"]
|
||||
|
||||
def test_server_url_trailing_slash_stripped(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp/")
|
||||
assert client.server_url == "https://mcp.example.com/mcp"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_send_request_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"result": {"tools": []},
|
||||
"id": 1,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
result = await client._send_request("tools/list")
|
||||
assert result == {"tools": []}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_send_request_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
async def mock_send(*args, **kwargs):
|
||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
||||
|
||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
||||
await client._send_request("tools/list")
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "get_weather"
|
||||
assert tools[0].description == "Get current weather for a city"
|
||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
||||
assert tools[1].name == "search"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools_empty(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_list_tools_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
||||
],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [{"type": "text", "text": "City not found"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "???"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_tool_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_initialize(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
||||
patch.object(client, "_send_notification") as mock_notif,
|
||||
):
|
||||
result = await client.initialize()
|
||||
|
||||
mock_req.assert_called_once()
|
||||
mock_notif.assert_called_once_with("notifications/initialized")
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
|
||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||
|
||||
MOCK_USER_ID = "test-user-123"
|
||||
|
||||
|
||||
class TestMCPToolBlock:
|
||||
"""Tests for the MCPToolBlock."""
|
||||
|
||||
def test_block_instantiation(self):
|
||||
block = MCPToolBlock()
|
||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
assert block.name == "MCPToolBlock"
|
||||
|
||||
def test_input_schema_has_required_fields(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.input_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "server_url" in props
|
||||
assert "selected_tool" in props
|
||||
assert "tool_arguments" in props
|
||||
assert "credentials" in props
|
||||
|
||||
def test_output_schema(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.output_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "result" in props
|
||||
assert "error" in props
|
||||
|
||||
def test_get_input_schema_with_tool_schema(self):
|
||||
tool_schema = {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
}
|
||||
data = {"tool_input_schema": tool_schema}
|
||||
result = MCPToolBlock.Input.get_input_schema(data)
|
||||
assert result == tool_schema
|
||||
|
||||
def test_get_input_schema_without_tool_schema(self):
|
||||
result = MCPToolBlock.Input.get_input_schema({})
|
||||
assert result == {}
|
||||
|
||||
def test_get_input_defaults(self):
|
||||
data = {"tool_arguments": {"city": "London"}}
|
||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
||||
assert result == {"city": "London"}
|
||||
|
||||
def test_get_missing_input(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"units": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "units"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == {"units"}
|
||||
|
||||
def test_get_missing_input_all_present(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == set()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_with_mock(self):
|
||||
"""Test the block using the built-in test infrastructure."""
|
||||
block = MCPToolBlock()
|
||||
await execute_block_test(block)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_missing_server_url(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="",
|
||||
selected_tool="test",
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [("error", "MCP server URL is required")]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_missing_tool(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="",
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [
|
||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_success(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
tool_arguments={"city": "London"},
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
return {"temp": 20, "city": "London"}
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_mcp_error(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="bad_tool",
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
raise MCPClientError("Tool not found")
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert outputs[0][0] == "error"
|
||||
assert "Tool not found" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_parses_json_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": '{"temp": 20}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {"temp": 20}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_plain_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Hello, world!"},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == "Hello, world!"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_multiple_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Part 1"},
|
||||
{"type": "text", "text": '{"part": 2}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == ["Part 1", {"part": 2}]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_error_result(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[{"type": "text", "text": "Something went wrong"}],
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
with pytest.raises(MCPClientError, match="returned an error"):
|
||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_call_mcp_tool_image_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_with_credentials(self):
|
||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
)
|
||||
|
||||
captured_tokens: list[str | None] = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
test_creds = OAuth2Credentials(
|
||||
id="cred-123",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("resolved-token"),
|
||||
refresh_token=SecretStr(""),
|
||||
scopes=[],
|
||||
title="Test MCP credential",
|
||||
)
|
||||
|
||||
async for _ in block.run(
|
||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
||||
):
|
||||
pass
|
||||
|
||||
assert captured_tokens == ["resolved-token"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_without_credentials(self):
|
||||
"""Verify the block works without credentials (public server)."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
)
|
||||
|
||||
captured_tokens: list[str | None] = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert captured_tokens == [None]
|
||||
assert outputs == [("result", "ok")]
|
||||
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Tests for MCP OAuth handler.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
|
||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.ok = 200 <= status < 300
|
||||
resp.json.return_value = json_data
|
||||
return resp
|
||||
|
||||
|
||||
class TestMCPOAuthHandler:
|
||||
"""Tests for the MCPOAuthHandler."""
|
||||
|
||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
||||
defaults = {
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
"authorize_url": "https://auth.example.com/authorize",
|
||||
"token_url": "https://auth.example.com/token",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MCPOAuthHandler(**defaults)
|
||||
|
||||
def test_get_login_url_basic(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(
|
||||
scopes=["read", "write"],
|
||||
state="random-state-token",
|
||||
code_challenge="S256-challenge-value",
|
||||
)
|
||||
|
||||
assert "https://auth.example.com/authorize?" in url
|
||||
assert "response_type=code" in url
|
||||
assert "client_id=test-client-id" in url
|
||||
assert "state=random-state-token" in url
|
||||
assert "code_challenge=S256-challenge-value" in url
|
||||
assert "code_challenge_method=S256" in url
|
||||
assert "scope=read+write" in url
|
||||
|
||||
def test_get_login_url_with_resource(self):
|
||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
||||
url = handler.get_login_url(
|
||||
scopes=[], state="state", code_challenge="challenge"
|
||||
)
|
||||
|
||||
assert "resource=https" in url
|
||||
|
||||
def test_get_login_url_without_pkce(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
||||
|
||||
assert "code_challenge" not in url
|
||||
assert "code_challenge_method" not in url
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_exchange_code_for_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
creds = await handler.exchange_code_for_tokens(
|
||||
code="auth-code",
|
||||
scopes=["read"],
|
||||
code_verifier="pkce-verifier",
|
||||
)
|
||||
|
||||
assert isinstance(creds, OAuth2Credentials)
|
||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
||||
assert creds.refresh_token is not None
|
||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
||||
assert creds.scopes == ["read"]
|
||||
assert creds.access_token_expires_at is not None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
existing_creds = OAuth2Credentials(
|
||||
id="existing-id",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("old-token"),
|
||||
refresh_token=SecretStr("old-refresh"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "refreshed-token",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
refreshed = await handler._refresh_tokens(existing_creds)
|
||||
|
||||
assert refreshed.id == "existing-id"
|
||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
||||
assert refreshed.refresh_token is not None
|
||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_tokens_no_refresh_token(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No refresh token"):
|
||||
await handler._refresh_tokens(creds)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_revoke_tokens_no_url(self):
|
||||
handler = self._make_handler(revoke_url=None)
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_revoke_tokens_with_url(self):
|
||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response({}, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestMCPClientDiscovery:
|
||||
"""Tests for MCPClient OAuth metadata discovery."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
metadata = {
|
||||
"authorization_servers": ["https://auth.example.com"],
|
||||
"resource": "https://mcp.example.com/mcp",
|
||||
}
|
||||
|
||||
resp = _mock_response(metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_not_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
resp = _mock_response({}, status=404)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_auth_server_metadata(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
server_metadata = {
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"registration_endpoint": "https://auth.example.com/register",
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
}
|
||||
|
||||
resp = _mock_response(server_metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth_server_metadata(
|
||||
"https://auth.example.com"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
||||
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Minimal MCP server for integration testing.
|
||||
|
||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
||||
with a few sample tools. Runs on localhost with a random available port.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sample tools this test server exposes
|
||||
TEST_TOOLS = [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "add_numbers",
|
||||
"description": "Add two numbers together",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echo back the input message",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string", "description": "Message to echo"},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _handle_initialize(params: dict) -> dict:
|
||||
return {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {"listChanged": False}},
|
||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
|
||||
def _handle_tools_list(params: dict) -> dict:
|
||||
return {"tools": TEST_TOOLS}
|
||||
|
||||
|
||||
def _handle_tools_call(params: dict) -> dict:
|
||||
tool_name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if tool_name == "get_weather":
|
||||
city = arguments.get("city", "Unknown")
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
elif tool_name == "add_numbers":
|
||||
a = arguments.get("a", 0)
|
||||
b = arguments.get("b", 0)
|
||||
return {
|
||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
||||
}
|
||||
|
||||
elif tool_name == "echo":
|
||||
message = arguments.get("message", "")
|
||||
return {
|
||||
"content": [{"type": "text", "text": message}],
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
HANDLERS = {
|
||||
"initialize": _handle_initialize,
|
||||
"tools/list": _handle_tools_list,
|
||||
"tools/call": _handle_tools_call,
|
||||
}
|
||||
|
||||
|
||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
||||
# Check auth if configured
|
||||
expected_token = request.app.get("auth_token")
|
||||
if expected_token:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header != f"Bearer {expected_token}":
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32001, "message": "Unauthorized"},
|
||||
"id": None,
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Handle notifications (no id field) — just acknowledge
|
||||
if "id" not in body:
|
||||
return web.Response(status=202)
|
||||
|
||||
method = body.get("method", "")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
handler = HANDLERS.get(method)
|
||||
if not handler:
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": f"Method not found: {method}",
|
||||
},
|
||||
"id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
result = handler(params)
|
||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
||||
|
||||
|
||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
||||
"""Create an aiohttp app that acts as an MCP server."""
|
||||
app = web.Application()
|
||||
app.router.add_post("/mcp", handle_mcp_request)
|
||||
if auth_token:
|
||||
app["auth_token"] = auth_token
|
||||
return app
|
||||
182
autogpt_platform/backend/backend/blocks/telegram/_api.py
Normal file
182
autogpt_platform/backend/backend/blocks/telegram/_api.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Telegram Bot API helper functions.
|
||||
|
||||
Provides utilities for making authenticated requests to the Telegram Bot API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TELEGRAM_API_BASE = "https://api.telegram.org"
|
||||
|
||||
|
||||
class TelegramMessageResult(BaseModel, extra="allow"):
|
||||
"""Result from Telegram send/edit message API calls."""
|
||||
|
||||
message_id: int = 0
|
||||
chat: dict[str, Any] = {}
|
||||
date: int = 0
|
||||
text: str = ""
|
||||
|
||||
|
||||
class TelegramFileResult(BaseModel, extra="allow"):
|
||||
"""Result from Telegram getFile API call."""
|
||||
|
||||
file_id: str = ""
|
||||
file_unique_id: str = ""
|
||||
file_size: int = 0
|
||||
file_path: str = ""
|
||||
|
||||
|
||||
class TelegramAPIException(ValueError):
|
||||
"""Exception raised for Telegram API errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int = 0):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
def get_bot_api_url(bot_token: str, method: str) -> str:
|
||||
"""Construct Telegram Bot API URL for a method."""
|
||||
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
|
||||
|
||||
|
||||
def get_file_url(bot_token: str, file_path: str) -> str:
|
||||
"""Construct Telegram file download URL."""
|
||||
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
|
||||
|
||||
|
||||
async def call_telegram_api(
|
||||
credentials: APIKeyCredentials,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
) -> TelegramMessageResult:
|
||||
"""
|
||||
Make a request to the Telegram Bot API.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
method: API method name (e.g., "sendMessage", "getFile")
|
||||
data: Request parameters
|
||||
|
||||
Returns:
|
||||
API response result
|
||||
|
||||
Raises:
|
||||
TelegramAPIException: If the API returns an error
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
url = get_bot_api_url(token, method)
|
||||
|
||||
response = await Requests().post(url, json=data or {})
|
||||
result = response.json()
|
||||
|
||||
if not result.get("ok"):
|
||||
error_code = result.get("error_code", 0)
|
||||
description = result.get("description", "Unknown error")
|
||||
raise TelegramAPIException(description, error_code)
|
||||
|
||||
return TelegramMessageResult(**result.get("result", {}))
|
||||
|
||||
|
||||
async def call_telegram_api_with_file(
|
||||
credentials: APIKeyCredentials,
|
||||
method: str,
|
||||
file_field: str,
|
||||
file_data: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
) -> TelegramMessageResult:
|
||||
"""
|
||||
Make a multipart/form-data request to the Telegram Bot API with a file upload.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
method: API method name (e.g., "sendPhoto", "sendVoice")
|
||||
file_field: Form field name for the file (e.g., "photo", "voice")
|
||||
file_data: Raw file bytes
|
||||
filename: Filename for the upload
|
||||
content_type: MIME type of the file
|
||||
data: Additional form parameters
|
||||
|
||||
Returns:
|
||||
API response result
|
||||
|
||||
Raises:
|
||||
TelegramAPIException: If the API returns an error
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
url = get_bot_api_url(token, method)
|
||||
|
||||
files = [(file_field, (filename, BytesIO(file_data), content_type))]
|
||||
|
||||
response = await Requests().post(url, files=files, data=data or {})
|
||||
result = response.json()
|
||||
|
||||
if not result.get("ok"):
|
||||
error_code = result.get("error_code", 0)
|
||||
description = result.get("description", "Unknown error")
|
||||
raise TelegramAPIException(description, error_code)
|
||||
|
||||
return TelegramMessageResult(**result.get("result", {}))
|
||||
|
||||
|
||||
async def get_file_info(
|
||||
credentials: APIKeyCredentials, file_id: str
|
||||
) -> TelegramFileResult:
|
||||
"""
|
||||
Get file information from Telegram.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id from message
|
||||
|
||||
Returns:
|
||||
File info dict containing file_id, file_unique_id, file_size, file_path
|
||||
"""
|
||||
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
|
||||
return TelegramFileResult(**result.model_dump())
|
||||
|
||||
|
||||
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
|
||||
"""
|
||||
Get the download URL for a Telegram file.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id from message
|
||||
|
||||
Returns:
|
||||
Full download URL
|
||||
"""
|
||||
token = credentials.api_key.get_secret_value()
|
||||
result = await get_file_info(credentials, file_id)
|
||||
file_path = result.file_path
|
||||
if not file_path:
|
||||
raise TelegramAPIException("No file_path returned from getFile")
|
||||
return get_file_url(token, file_path)
|
||||
|
||||
|
||||
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
|
||||
"""
|
||||
Download a file from Telegram servers.
|
||||
|
||||
Args:
|
||||
credentials: Bot token credentials
|
||||
file_id: Telegram file_id
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
url = await get_file_download_url(credentials, file_id)
|
||||
response = await Requests().get(url)
|
||||
return response.content
|
||||
43
autogpt_platform/backend/backend/blocks/telegram/_auth.py
Normal file
43
autogpt_platform/backend/backend/blocks/telegram/_auth.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Telegram Bot credentials handling.
|
||||
|
||||
Telegram bots use an API key (bot token) obtained from @BotFather.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Bot token credentials (API key style)
|
||||
TelegramCredentials = APIKeyCredentials
|
||||
TelegramCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.TELEGRAM], Literal["api_key"]
|
||||
]
|
||||
|
||||
|
||||
def TelegramCredentialsField() -> TelegramCredentialsInput:
|
||||
"""Creates a Telegram bot token credentials field."""
|
||||
return CredentialsField(
|
||||
description="Telegram Bot API token from @BotFather. "
|
||||
"Create a bot at https://t.me/BotFather to get your token."
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for unit tests
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="telegram",
|
||||
api_key=SecretStr("test_telegram_bot_token"),
|
||||
title="Mock Telegram Bot Token",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
1254
autogpt_platform/backend/backend/blocks/telegram/blocks.py
Normal file
1254
autogpt_platform/backend/backend/blocks/telegram/blocks.py
Normal file
File diff suppressed because it is too large
Load Diff
377
autogpt_platform/backend/backend/blocks/telegram/triggers.py
Normal file
377
autogpt_platform/backend/backend/blocks/telegram/triggers.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Telegram trigger blocks for receiving messages via webhooks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.telegram import TelegramWebhookType
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TelegramCredentialsField,
|
||||
TelegramCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Example payload for testing
|
||||
EXAMPLE_MESSAGE_PAYLOAD = {
|
||||
"update_id": 123456789,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"language_code": "en",
|
||||
},
|
||||
"chat": {
|
||||
"id": 12345678,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"type": "private",
|
||||
},
|
||||
"date": 1234567890,
|
||||
"text": "Hello, bot!",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TelegramTriggerBase:
|
||||
"""Base class for Telegram trigger blocks."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: TelegramCredentialsInput = TelegramCredentialsField()
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
|
||||
|
||||
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
|
||||
"""
|
||||
Triggers when a message is received or edited in your Telegram bot.
|
||||
|
||||
Supports text, photos, voice messages, audio files, documents, and videos.
|
||||
Connect the outputs to other blocks to process messages and send responses.
|
||||
"""
|
||||
|
||||
class Input(TelegramTriggerBase.Input):
|
||||
class EventsFilter(BaseModel):
|
||||
"""Filter for message types to receive."""
|
||||
|
||||
text: bool = True
|
||||
photo: bool = False
|
||||
voice: bool = False
|
||||
audio: bool = False
|
||||
document: bool = False
|
||||
video: bool = False
|
||||
edited_message: bool = False
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Message Types", description="Types of messages to receive"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload from Telegram"
|
||||
)
|
||||
chat_id: int = SchemaField(
|
||||
description="The chat ID where the message was received. "
|
||||
"Use this to send replies."
|
||||
)
|
||||
message_id: int = SchemaField(description="The unique message ID")
|
||||
user_id: int = SchemaField(description="The user ID who sent the message")
|
||||
username: str = SchemaField(description="Username of the sender (may be empty)")
|
||||
first_name: str = SchemaField(description="First name of the sender")
|
||||
event: str = SchemaField(
|
||||
description="The message type (text, photo, voice, audio, etc.)"
|
||||
)
|
||||
text: str = SchemaField(
|
||||
description="Text content of the message (for text messages)"
|
||||
)
|
||||
photo_file_id: str = SchemaField(
|
||||
description="File ID of the photo (for photo messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
voice_file_id: str = SchemaField(
|
||||
description="File ID of the voice message (for voice messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
audio_file_id: str = SchemaField(
|
||||
description="File ID of the audio file (for audio messages). "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
file_id: str = SchemaField(
|
||||
description="File ID for document/video messages. "
|
||||
"Use GetTelegramFileBlock to download."
|
||||
)
|
||||
file_name: str = SchemaField(
|
||||
description="Original filename (for document/audio messages)"
|
||||
)
|
||||
caption: str = SchemaField(description="Caption for media messages")
|
||||
is_edited: bool = SchemaField(
|
||||
description="Whether this is an edit of a previously sent message"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
|
||||
description="Triggers when a message is received or edited in your Telegram bot. "
|
||||
"Supports text, photos, voice messages, audio files, documents, and videos.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TelegramMessageTriggerBlock.Input,
|
||||
output_schema=TelegramMessageTriggerBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.TELEGRAM,
|
||||
webhook_type=TelegramWebhookType.BOT,
|
||||
resource_format="bot",
|
||||
event_filter_input="events",
|
||||
event_format="message.{event}",
|
||||
),
|
||||
test_input={
|
||||
"events": {"text": True, "photo": True},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"payload": EXAMPLE_MESSAGE_PAYLOAD,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("payload", EXAMPLE_MESSAGE_PAYLOAD),
|
||||
("chat_id", 12345678),
|
||||
("message_id", 1),
|
||||
("user_id", 12345678),
|
||||
("username", "johndoe"),
|
||||
("first_name", "John"),
|
||||
("is_edited", False),
|
||||
("event", "text"),
|
||||
("text", "Hello, bot!"),
|
||||
("photo_file_id", ""),
|
||||
("voice_file_id", ""),
|
||||
("audio_file_id", ""),
|
||||
("file_id", ""),
|
||||
("file_name", ""),
|
||||
("caption", ""),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
is_edited = "edited_message" in payload
|
||||
message = payload.get("message") or payload.get("edited_message", {})
|
||||
|
||||
# Extract common fields
|
||||
chat = message.get("chat", {})
|
||||
sender = message.get("from", {})
|
||||
|
||||
yield "payload", payload
|
||||
yield "chat_id", chat.get("id", 0)
|
||||
yield "message_id", message.get("message_id", 0)
|
||||
yield "user_id", sender.get("id", 0)
|
||||
yield "username", sender.get("username", "")
|
||||
yield "first_name", sender.get("first_name", "")
|
||||
yield "is_edited", is_edited
|
||||
|
||||
# For edited messages, yield event as "edited_message" and extract
|
||||
# all content fields from the edited message body
|
||||
if is_edited:
|
||||
yield "event", "edited_message"
|
||||
yield "text", message.get("text", "")
|
||||
photos = message.get("photo", [])
|
||||
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
|
||||
voice = message.get("voice", {})
|
||||
yield "voice_file_id", voice.get("file_id", "")
|
||||
audio = message.get("audio", {})
|
||||
yield "audio_file_id", audio.get("file_id", "")
|
||||
document = message.get("document", {})
|
||||
video = message.get("video", {})
|
||||
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
|
||||
yield "file_name", (
|
||||
document.get("file_name", "") or audio.get("file_name", "")
|
||||
)
|
||||
yield "caption", message.get("caption", "")
|
||||
# Determine message type and extract content
|
||||
elif "text" in message:
|
||||
yield "event", "text"
|
||||
yield "text", message.get("text", "")
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", ""
|
||||
elif "photo" in message:
|
||||
# Get the largest photo (last in array)
|
||||
photos = message.get("photo", [])
|
||||
photo_fid = photos[-1].get("file_id", "") if photos else ""
|
||||
yield "event", "photo"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", photo_fid
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "voice" in message:
|
||||
voice = message.get("voice", {})
|
||||
yield "event", "voice"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", voice.get("file_id", "")
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "audio" in message:
|
||||
audio = message.get("audio", {})
|
||||
yield "event", "audio"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", audio.get("file_id", "")
|
||||
yield "file_id", ""
|
||||
yield "file_name", audio.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "document" in message:
|
||||
document = message.get("document", {})
|
||||
yield "event", "document"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", document.get("file_id", "")
|
||||
yield "file_name", document.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
elif "video" in message:
|
||||
video = message.get("video", {})
|
||||
yield "event", "video"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", video.get("file_id", "")
|
||||
yield "file_name", video.get("file_name", "")
|
||||
yield "caption", message.get("caption", "")
|
||||
else:
|
||||
yield "event", "other"
|
||||
yield "text", ""
|
||||
yield "photo_file_id", ""
|
||||
yield "voice_file_id", ""
|
||||
yield "audio_file_id", ""
|
||||
yield "file_id", ""
|
||||
yield "file_name", ""
|
||||
yield "caption", ""
|
||||
|
||||
|
||||
# Example payload for reaction trigger testing
|
||||
EXAMPLE_REACTION_PAYLOAD = {
|
||||
"update_id": 123456790,
|
||||
"message_reaction": {
|
||||
"chat": {
|
||||
"id": 12345678,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"username": "johndoe",
|
||||
"type": "private",
|
||||
},
|
||||
"message_id": 42,
|
||||
"user": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "John",
|
||||
"username": "johndoe",
|
||||
},
|
||||
"date": 1234567890,
|
||||
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
|
||||
"old_reaction": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
|
||||
"""
|
||||
Triggers when a reaction to a message is changed.
|
||||
|
||||
Works automatically in private chats. In group chats, the bot must be
|
||||
an administrator to receive reaction updates.
|
||||
"""
|
||||
|
||||
class Input(TelegramTriggerBase.Input):
|
||||
pass
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: dict = SchemaField(
|
||||
description="The complete webhook payload from Telegram"
|
||||
)
|
||||
chat_id: int = SchemaField(
|
||||
description="The chat ID where the reaction occurred"
|
||||
)
|
||||
message_id: int = SchemaField(description="The message ID that was reacted to")
|
||||
user_id: int = SchemaField(description="The user ID who changed the reaction")
|
||||
username: str = SchemaField(description="Username of the user (may be empty)")
|
||||
new_reactions: list = SchemaField(
|
||||
description="List of new reactions on the message"
|
||||
)
|
||||
old_reactions: list = SchemaField(
|
||||
description="List of previous reactions on the message"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82525328-9368-4966-8f0c-cd78e80181fd",
|
||||
description="Triggers when a reaction to a message is changed. "
|
||||
"Works in private chats automatically. "
|
||||
"In groups, the bot must be an administrator.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TelegramMessageReactionTriggerBlock.Input,
|
||||
output_schema=TelegramMessageReactionTriggerBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.TELEGRAM,
|
||||
webhook_type=TelegramWebhookType.BOT,
|
||||
resource_format="bot",
|
||||
event_filter_input="",
|
||||
event_format="message_reaction",
|
||||
),
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"payload": EXAMPLE_REACTION_PAYLOAD,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("payload", EXAMPLE_REACTION_PAYLOAD),
|
||||
("chat_id", 12345678),
|
||||
("message_id", 42),
|
||||
("user_id", 12345678),
|
||||
("username", "johndoe"),
|
||||
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
|
||||
("old_reactions", []),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
reaction = payload.get("message_reaction", {})
|
||||
|
||||
chat = reaction.get("chat", {})
|
||||
user = reaction.get("user", {})
|
||||
|
||||
yield "payload", payload
|
||||
yield "chat_id", chat.get("id", 0)
|
||||
yield "message_id", reaction.get("message_id", 0)
|
||||
yield "user_id", user.get("id", 0)
|
||||
yield "username", user.get("username", "")
|
||||
yield "new_reactions", reaction.get("new_reaction", [])
|
||||
yield "old_reactions", reaction.get("old_reaction", [])
|
||||
@@ -34,10 +34,12 @@ def main(output: Path, pretty: bool):
|
||||
"""Generate and output the OpenAPI JSON specification."""
|
||||
openapi_schema = get_openapi_schema()
|
||||
|
||||
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
|
||||
json_output = json.dumps(
|
||||
openapi_schema, indent=2 if pretty else None, ensure_ascii=False
|
||||
)
|
||||
|
||||
if output:
|
||||
output.write_text(json_output)
|
||||
output.write_text(json_output, encoding="utf-8")
|
||||
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
|
||||
click.echo(f"\n{json_output[:500]} ...")
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -27,6 +28,54 @@ async def server():
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||
async def graph_cleanup(server):
|
||||
created_graph_ids = []
|
||||
|
||||
1
autogpt_platform/backend/backend/copilot/__init__.py
Normal file
1
autogpt_platform/backend/backend/copilot/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -27,63 +27,38 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
# Long-running operation configuration
|
||||
long_running_operation_ttl: int = Field(
|
||||
default=600,
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_lock_ttl: int = Field(
|
||||
default=120,
|
||||
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
|
||||
"reconnection after refresh/crash without long waits.",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
session_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
description="Prefix for session metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
turn_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
description="Prefix for turn message stream keys",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
@@ -93,6 +68,31 @@ class ChatConfig(BaseSettings):
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
description="Model for the Claude Agent SDK path. If None, derives from "
|
||||
"the `model` field by stripping the OpenRouter provider prefix.",
|
||||
)
|
||||
claude_agent_max_buffer_size: int = Field(
|
||||
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
||||
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
|
||||
"Increase if tool outputs exceed the limit.",
|
||||
)
|
||||
claude_agent_max_subtasks: int = Field(
|
||||
default=10,
|
||||
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
|
||||
)
|
||||
claude_agent_use_resume: bool = Field(
|
||||
default=True,
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
@@ -130,13 +130,16 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
@@ -3,8 +3,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
@@ -14,29 +15,27 @@ from prisma.types import (
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.data import db
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
from .model import ChatMessage, ChatSession, ChatSessionInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": True},
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
) -> ChatSessionInfo:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
@@ -45,7 +44,8 @@ async def create_chat_session(
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(data=data)
|
||||
prisma_session = await PrismaChatSession.prisma().create(data=data)
|
||||
return ChatSessionInfo.from_db(prisma_session)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
@@ -56,7 +56,7 @@ async def update_chat_session(
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
) -> ChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
@@ -76,12 +76,9 @@ async def update_chat_session(
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
return ChatSession.from_db(session) if session else None
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
@@ -94,12 +91,11 @@ async def add_chat_message(
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
) -> ChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
@@ -127,81 +123,117 @@ async def add_chat_message(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
PrismaChatMessage.prisma().create(data=data),
|
||||
)
|
||||
return message
|
||||
return ChatMessage.from_db(message)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
) -> int:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||
streaming loop and long-running callback race), queries the latest
|
||||
sequence and retries with the correct offset. This avoids unnecessary
|
||||
upserts and DB queries in the common case (no collision).
|
||||
|
||||
Returns:
|
||||
Next sequence number for the next message to be inserted. This equals
|
||||
start_sequence + len(messages) and allows callers to update their
|
||||
counters even when collision detection adjusts start_sequence.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
# No messages to add - return current count
|
||||
return start_sequence
|
||||
|
||||
created_messages = []
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Single timestamp for all messages and session update
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
async with db.transaction() as tx:
|
||||
# Build all message data
|
||||
messages_data = []
|
||||
for i, msg in enumerate(messages):
|
||||
# Build ChatMessageCreateInput with only non-None values
|
||||
# (Prisma TypedDict rejects optional fields set to None)
|
||||
# Note: create_many doesn't support nested creates, use sessionId directly
|
||||
data: ChatMessageCreateInput = {
|
||||
"sessionId": session_id,
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
"createdAt": now,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
messages_data.append(data)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
# Run create_many and session update in parallel within transaction
|
||||
# Both use the same timestamp for consistency
|
||||
await asyncio.gather(
|
||||
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
|
||||
PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": now},
|
||||
),
|
||||
)
|
||||
|
||||
return created_messages
|
||||
# Return next sequence number for counter sync
|
||||
return start_sequence + len(messages)
|
||||
|
||||
except UniqueViolationError:
|
||||
if attempt < max_retries - 1:
|
||||
# Collision detected - query MAX(sequence)+1 and retry with correct offset
|
||||
logger.info(
|
||||
f"Collision detected for session {session_id} at sequence "
|
||||
f"{start_sequence}, querying DB for latest sequence"
|
||||
)
|
||||
start_sequence = await get_next_sequence(session_id)
|
||||
logger.info(
|
||||
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded - propagate error
|
||||
raise
|
||||
|
||||
# Should never reach here due to raise in exception handler
|
||||
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[PrismaChatSession]:
|
||||
) -> list[ChatSessionInfo]:
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
prisma_sessions = await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
|
||||
|
||||
|
||||
async def get_user_session_count(user_id: str) -> int:
|
||||
@@ -240,10 +272,20 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
return False
|
||||
|
||||
|
||||
async def get_chat_session_message_count(session_id: str) -> int:
|
||||
"""Get the number of messages in a chat session."""
|
||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||
return count
|
||||
async def get_next_sequence(session_id: str) -> int:
|
||||
"""Get the next sequence number for a new message in this session.
|
||||
|
||||
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
|
||||
More robust than COUNT(*) because it's immune to deleted messages.
|
||||
|
||||
Optimized to select only the sequence column using raw SQL.
|
||||
The unique index on (sessionId, sequence) makes this query fast.
|
||||
"""
|
||||
results = await db.query_raw_with_schema(
|
||||
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
|
||||
session_id,
|
||||
)
|
||||
return 0 if not results else results[0]["sequence"] + 1
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Entry point for running the CoPilot Executor service.
|
||||
|
||||
Usage:
|
||||
python -m backend.copilot.executor
|
||||
"""
|
||||
|
||||
from backend.app import run_processes
|
||||
|
||||
from .manager import CoPilotExecutor
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the CoPilot Executor service."""
|
||||
run_processes(CoPilotExecutor())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
526
autogpt_platform/backend/backend/copilot/executor/manager.py
Normal file
526
autogpt_platform/backend/backend/copilot/executor/manager.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""CoPilot Executor Manager - main service for CoPilot task execution.
|
||||
|
||||
This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.exceptions import AMQPChannelError, AMQPConnectionError
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_turn, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
settings = Settings()
|
||||
|
||||
# Prometheus metrics
|
||||
active_tasks_gauge = Gauge(
|
||||
"copilot_executor_active_tasks",
|
||||
"Number of active CoPilot tasks",
|
||||
)
|
||||
pool_size_gauge = Gauge(
|
||||
"copilot_executor_pool_size",
|
||||
"Maximum number of CoPilot executor workers",
|
||||
)
|
||||
utilization_gauge = Gauge(
|
||||
"copilot_executor_utilization_ratio",
|
||||
"Ratio of active tasks to pool size",
|
||||
)
|
||||
|
||||
|
||||
class CoPilotExecutor(AppProcess):
|
||||
"""CoPilot Executor service for processing chat generation tasks.
|
||||
|
||||
This service consumes tasks from RabbitMQ, processes them using a thread pool,
|
||||
and publishes results to Redis Streams. It follows the graph executor pattern
|
||||
for reliable message handling and graceful shutdown.
|
||||
|
||||
Key features:
|
||||
- RabbitMQ-based task distribution with manual acknowledgment
|
||||
- Thread pool executor for concurrent task processing
|
||||
- Cluster lock for duplicate prevention across pods
|
||||
- Graceful shutdown with timeout for in-flight tasks
|
||||
- FANOUT exchange for cancellation broadcast
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_copilot_workers
|
||||
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
|
||||
self._cancel_thread = None
|
||||
self._cancel_client = None
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._task_locks: dict[str, ClusterLock] = {}
|
||||
self._active_tasks_lock = threading.Lock()
|
||||
|
||||
# ============ Main Entry Points (AppProcess interface) ============ #
|
||||
|
||||
def run(self):
|
||||
"""Main service loop - consume from RabbitMQ."""
|
||||
logger.info(f"Pod assigned executor_id: {self.executor_id}")
|
||||
logger.info(f"Spawn max-{self.pool_size} workers...")
|
||||
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
self._update_metrics()
|
||||
start_http_server(settings.config.copilot_executor_port)
|
||||
|
||||
self.cancel_thread.start()
|
||||
self.run_thread.start()
|
||||
|
||||
while True:
|
||||
time.sleep(1e5)
|
||||
|
||||
def cleanup(self):
|
||||
"""Graceful shutdown with active execution waiting."""
|
||||
pid = os.getpid()
|
||||
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
|
||||
|
||||
# Signal the consumer thread to stop
|
||||
try:
|
||||
self.stop_consuming.set()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: run_channel.stop_consuming()
|
||||
)
|
||||
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_tasks:
|
||||
logger.info(
|
||||
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
|
||||
)
|
||||
|
||||
start_time = time.monotonic()
|
||||
last_refresh = start_time
|
||||
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
|
||||
|
||||
while (
|
||||
self.active_tasks
|
||||
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
||||
):
|
||||
self._cleanup_completed_tasks()
|
||||
if not self.active_tasks:
|
||||
break
|
||||
|
||||
# Refresh cluster locks periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= lock_refresh_interval:
|
||||
for lock in list(self._task_locks.values()):
|
||||
try:
|
||||
lock.refresh()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[cleanup {pid}] Failed to refresh lock: {e}"
|
||||
)
|
||||
last_refresh = current_time
|
||||
|
||||
logger.info(
|
||||
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
|
||||
)
|
||||
time.sleep(10.0)
|
||||
|
||||
# Stop message consumers
|
||||
if self._run_thread:
|
||||
self._stop_message_consumers(
|
||||
self._run_thread, self.run_client, "[cleanup][run]"
|
||||
)
|
||||
if self._cancel_thread:
|
||||
self._stop_message_consumers(
|
||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
||||
)
|
||||
|
||||
# Clean up worker threads (closes per-loop workspace storage sessions)
|
||||
if self._executor:
|
||||
from .processor import cleanup_worker
|
||||
|
||||
logger.info(f"[cleanup {pid}] Cleaning up workers...")
|
||||
futures = []
|
||||
for _ in range(self._executor._max_workers):
|
||||
futures.append(self._executor.submit(cleanup_worker))
|
||||
for f in futures:
|
||||
try:
|
||||
f.result(timeout=10)
|
||||
except Exception as e:
|
||||
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
|
||||
|
||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for session_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
|
||||
|
||||
# ============ RabbitMQ Consumer Methods ============ #
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_cancel(self):
|
||||
"""Consume cancellation messages from FANOUT exchange."""
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop reconnecting cancel consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.cancel_client.is_ready:
|
||||
self.cancel_client.disconnect()
|
||||
self.cancel_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.cancel_client.disconnect()
|
||||
return
|
||||
|
||||
cancel_channel = self.cancel_client.get_channel()
|
||||
cancel_channel.basic_consume(
|
||||
queue=COPILOT_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
logger.info("Starting to consume cancel messages...")
|
||||
cancel_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set() or self.active_tasks:
|
||||
raise RuntimeError("Cancel message consumer stopped unexpectedly")
|
||||
logger.info("Cancel message consumer stopped gracefully")
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_run(self):
|
||||
"""Consume run messages from DIRECT exchange."""
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop reconnecting run consumer - service cleaned up")
|
||||
return
|
||||
|
||||
if not self.run_client.is_ready:
|
||||
self.run_client.disconnect()
|
||||
self.run_client.connect()
|
||||
|
||||
# Check again after connect - shutdown may have been requested
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||
self.run_client.disconnect()
|
||||
return
|
||||
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||
|
||||
run_channel.basic_consume(
|
||||
queue=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
on_message_callback=self._handle_run_message,
|
||||
auto_ack=False,
|
||||
consumer_tag="copilot_execution_consumer",
|
||||
)
|
||||
logger.info("Starting to consume run messages...")
|
||||
run_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set():
|
||||
raise RuntimeError("Run message consumer stopped unexpectedly")
|
||||
logger.info("Run message consumer stopped gracefully")
|
||||
|
||||
# ============ Message Handlers ============ #
|
||||
|
||||
@error_logged(swallow=True)
|
||||
def _handle_cancel_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
_method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle cancel message from FANOUT exchange."""
|
||||
request = CancelCoPilotEvent.model_validate_json(body)
|
||||
session_id = request.session_id
|
||||
if not session_id:
|
||||
logger.warning("Cancel message missing 'session_id'")
|
||||
return
|
||||
if session_id not in self.active_tasks:
|
||||
logger.debug(f"Cancel received for {session_id} but not active")
|
||||
return
|
||||
|
||||
_, cancel_event = self.active_tasks[session_id]
|
||||
logger.info(f"Received cancel for {session_id}")
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
else:
|
||||
logger.debug(f"Cancel already set for {session_id}")
|
||||
|
||||
def _handle_run_message(
|
||||
self,
|
||||
_channel: BlockingChannel,
|
||||
method: Basic.Deliver,
|
||||
_properties: BasicProperties,
|
||||
body: bytes,
|
||||
):
|
||||
"""Handle run message from DIRECT exchange."""
|
||||
delivery_tag = method.delivery_tag
|
||||
# Capture the channel used at message delivery time to ensure we ack
|
||||
# on the correct channel. Delivery tags are channel-scoped and become
|
||||
# invalid if the channel is recreated after reconnection.
|
||||
delivery_channel = _channel
|
||||
|
||||
def ack_message(reject: bool, requeue: bool):
|
||||
"""Acknowledge or reject the message.
|
||||
|
||||
Uses the channel from the original message delivery. If the channel
|
||||
is no longer open (e.g., after reconnection), logs a warning and
|
||||
skips the ack - RabbitMQ will redeliver the message automatically.
|
||||
"""
|
||||
try:
|
||||
if not delivery_channel.is_open:
|
||||
logger.warning(
|
||||
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
|
||||
"Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
return
|
||||
|
||||
if reject:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_nack(
|
||||
delivery_tag, requeue=requeue
|
||||
)
|
||||
)
|
||||
else:
|
||||
delivery_channel.connection.add_callback_threadsafe(
|
||||
lambda: delivery_channel.basic_ack(delivery_tag)
|
||||
)
|
||||
except (AMQPChannelError, AMQPConnectionError) as e:
|
||||
# Channel/connection errors indicate stale delivery tag - don't retry
|
||||
logger.warning(
|
||||
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
|
||||
f"error: {e}. Message will be redelivered by RabbitMQ."
|
||||
)
|
||||
except Exception as e:
|
||||
# Other errors might be transient, but log and skip to avoid blocking
|
||||
logger.error(
|
||||
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
|
||||
)
|
||||
|
||||
# Check if we're shutting down
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info("Rejecting new task during shutdown")
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more tasks
|
||||
self._cleanup_completed_tasks()
|
||||
if len(self.active_tasks) >= self.pool_size:
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
try:
|
||||
entry = CoPilotExecutionEntry.model_validate_json(body)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not parse run message: {e}, body={body}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
session_id = entry.session_id
|
||||
|
||||
# Check for local duplicate - session is already running on this executor
|
||||
if session_id in self.active_tasks:
|
||||
logger.warning(
|
||||
f"Session {session_id} already running locally, rejecting duplicate"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
return
|
||||
|
||||
# Try to acquire cluster-wide lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"copilot:session:{session_id}:lock",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
if current_owner is not None:
|
||||
logger.warning(
|
||||
f"Session {session_id} already running on pod {current_owner}"
|
||||
)
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not acquire lock for {session_id} - Redis unavailable"
|
||||
)
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
self._task_locks[session_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"Acquired cluster lock for {session_id}, "
|
||||
f"executor_id={self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
future = self.executor.submit(
|
||||
execute_copilot_turn, entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_tasks[session_id] = (future, cancel_event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup execution for {session_id}: {e}")
|
||||
cluster_lock.release()
|
||||
if session_id in self._task_locks:
|
||||
del self._task_locks[session_id]
|
||||
ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
self._update_metrics()
|
||||
|
||||
def on_run_done(f: Future):
|
||||
logger.info(f"Run completed for {session_id}")
|
||||
error_msg = None
|
||||
try:
|
||||
if exec_error := f.exception():
|
||||
error_msg = str(exec_error) or type(exec_error).__name__
|
||||
logger.error(f"Execution for {session_id} failed: {error_msg}")
|
||||
ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
ack_message(reject=False, requeue=False)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Run completion callback cancelled for {session_id}")
|
||||
except BaseException as e:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.exception(f"Error in run completion callback: {error_msg}")
|
||||
finally:
|
||||
# Release the cluster lock
|
||||
if session_id in self._task_locks:
|
||||
logger.info(f"Releasing cluster lock for {session_id}")
|
||||
self._task_locks[session_id].release()
|
||||
del self._task_locks[session_id]
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
future.add_done_callback(on_run_done)
|
||||
|
||||
# ============ Helper Methods ============ #
|
||||
|
||||
def _cleanup_completed_tasks(self) -> list[str]:
|
||||
"""Remove completed futures from active_tasks and update metrics."""
|
||||
completed_tasks = []
|
||||
with self._active_tasks_lock:
|
||||
for session_id, (future, _) in list(self.active_tasks.items()):
|
||||
if future.done():
|
||||
completed_tasks.append(session_id)
|
||||
self.active_tasks.pop(session_id, None)
|
||||
logger.info(f"Cleaned up completed session {session_id}")
|
||||
|
||||
self._update_metrics()
|
||||
return completed_tasks
|
||||
|
||||
def _update_metrics(self):
|
||||
"""Update Prometheus metrics."""
|
||||
active_count = len(self.active_tasks)
|
||||
active_tasks_gauge.set(active_count)
|
||||
if self.stop_consuming.is_set():
|
||||
utilization_gauge.set(1.0)
|
||||
else:
|
||||
utilization_gauge.set(
|
||||
active_count / self.pool_size if self.pool_size > 0 else 0
|
||||
)
|
||||
|
||||
def _stop_message_consumers(
|
||||
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
|
||||
):
|
||||
"""Stop a message consumer thread."""
|
||||
try:
|
||||
channel = client.get_channel()
|
||||
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
|
||||
|
||||
thread.join(timeout=300)
|
||||
if thread.is_alive():
|
||||
logger.error(
|
||||
f"{prefix} Thread did not finish in time, forcing disconnect"
|
||||
)
|
||||
|
||||
client.disconnect()
|
||||
logger.info(f"{prefix} Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} Error disconnecting client: {e}")
|
||||
|
||||
# ============ Lazy-initialized Properties ============ #
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
self._cancel_thread = threading.Thread(
|
||||
target=lambda: self._consume_cancel(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._cancel_thread
|
||||
|
||||
@property
|
||||
def run_thread(self) -> threading.Thread:
|
||||
if self._run_thread is None:
|
||||
self._run_thread = threading.Thread(
|
||||
target=lambda: self._consume_run(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._run_thread
|
||||
|
||||
@property
|
||||
def stop_consuming(self) -> threading.Event:
|
||||
if self._stop_consuming is None:
|
||||
self._stop_consuming = threading.Event()
|
||||
return self._stop_consuming
|
||||
|
||||
@property
|
||||
def executor(self) -> ThreadPoolExecutor:
|
||||
if self._executor is None:
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=init_worker,
|
||||
)
|
||||
return self._executor
|
||||
|
||||
@property
|
||||
def cancel_client(self) -> SyncRabbitMQ:
|
||||
if self._cancel_client is None:
|
||||
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._cancel_client
|
||||
|
||||
@property
|
||||
def run_client(self) -> SyncRabbitMQ:
|
||||
if self._run_client is None:
|
||||
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||
return self._run_client
|
||||
276
autogpt_platform/backend/backend/copilot/executor/processor.py
Normal file
276
autogpt_platform/backend/backend/copilot/executor/processor.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""CoPilot execution processor - per-worker execution logic.
|
||||
|
||||
This module contains the processor class that handles CoPilot session execution
|
||||
in a thread-local context, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def execute_copilot_turn(
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a single CoPilot turn (user message → AI response).
|
||||
|
||||
This function is the entry point called by the thread pool executor.
|
||||
|
||||
Args:
|
||||
entry: The turn payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for this execution
|
||||
"""
|
||||
processor: CoPilotProcessor = _tls.processor
|
||||
return processor.execute(entry, cancel, cluster_lock)
|
||||
|
||||
|
||||
def init_worker():
|
||||
"""Initialize the processor for the current worker thread.
|
||||
|
||||
This function is called by the thread pool executor when a new worker
|
||||
thread is created. It ensures each worker has its own processor instance.
|
||||
"""
|
||||
_tls.processor = CoPilotProcessor()
|
||||
_tls.processor.on_executor_start()
|
||||
|
||||
|
||||
def cleanup_worker():
|
||||
"""Clean up the processor for the current worker thread.
|
||||
|
||||
Should be called before the worker thread's event loop is destroyed so
|
||||
that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
|
||||
closed on the correct loop.
|
||||
"""
|
||||
processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
|
||||
if processor is not None:
|
||||
processor.cleanup()
|
||||
|
||||
|
||||
# ============ Processor Class ============ #
|
||||
|
||||
|
||||
class CoPilotProcessor:
|
||||
"""Per-worker execution logic for CoPilot sessions.
|
||||
|
||||
This class is instantiated once per worker thread and handles the execution
|
||||
of CoPilot chat generation sessions. It maintains an async event loop for
|
||||
running the async service code.
|
||||
|
||||
The execution flow:
|
||||
1. Session entry is picked from RabbitMQ queue
|
||||
2. Manager submits to thread pool
|
||||
3. Processor executes in its event loop
|
||||
4. Results are published to Redis Streams
|
||||
"""
|
||||
|
||||
@func_retry
|
||||
def on_executor_start(self):
|
||||
"""Initialize the processor when the worker thread starts.
|
||||
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||
to Prisma directly.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
self.tid = threading.get_ident()
|
||||
self.execution_loop = asyncio.new_event_loop()
|
||||
self.execution_thread = threading.Thread(
|
||||
target=self.execution_loop.run_forever, daemon=True
|
||||
)
|
||||
self.execution_thread.start()
|
||||
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up event-loop-bound resources before the loop is destroyed.
|
||||
|
||||
Shuts down the workspace storage instance that belongs to this
|
||||
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
|
||||
runs on the same loop that created the session.
|
||||
"""
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
coro = shutdown_workspace_storage()
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
|
||||
future.result(timeout=5)
|
||||
except Exception as e:
|
||||
coro.close() # Prevent "coroutine was never awaited" warning
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.warning(
|
||||
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
|
||||
)
|
||||
|
||||
# Stop the event loop
|
||||
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
|
||||
self.execution_thread.join(timeout=5)
|
||||
logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def execute(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute a CoPilot turn.
|
||||
|
||||
Runs the async logic in the worker's event loop and handles errors.
|
||||
|
||||
Args:
|
||||
entry: The turn payload containing session and message info
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock to prevent duplicate execution
|
||||
"""
|
||||
log = CoPilotLogMetadata(
|
||||
logging.getLogger(__name__),
|
||||
session_id=entry.session_id,
|
||||
user_id=entry.user_id,
|
||||
)
|
||||
log.info("Starting execution")
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Run the async execution in our event loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._execute_async(entry, cancel, cluster_lock, log),
|
||||
self.execution_loop,
|
||||
)
|
||||
|
||||
# Wait for completion, checking cancel periodically
|
||||
while not future.done():
|
||||
try:
|
||||
future.result(timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
if cancel.is_set():
|
||||
log.info("Cancellation requested")
|
||||
future.cancel()
|
||||
break
|
||||
# Refresh cluster lock to maintain ownership
|
||||
cluster_lock.refresh()
|
||||
|
||||
if not future.cancelled():
|
||||
# Get result to propagate any exceptions
|
||||
future.result()
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||
|
||||
async def _execute_async(
|
||||
self,
|
||||
entry: CoPilotExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
log: CoPilotLogMetadata,
|
||||
):
|
||||
"""Async execution logic for a CoPilot turn.
|
||||
|
||||
Calls the stream_chat_completion service function and publishes
|
||||
results to the stream registry.
|
||||
|
||||
Args:
|
||||
entry: The turn payload
|
||||
cancel: Threading event to signal cancellation
|
||||
cluster_lock: Distributed lock for refresh
|
||||
log: Structured logger
|
||||
"""
|
||||
last_refresh = time.monotonic()
|
||||
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||
error_msg = None
|
||||
|
||||
try:
|
||||
# Choose service based on LaunchDarkly flag
|
||||
config = ChatConfig()
|
||||
use_sdk = await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else copilot_service.stream_chat_completion
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
session_id=entry.session_id,
|
||||
message=entry.message if entry.message else None,
|
||||
is_user_message=entry.is_user_message,
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
):
|
||||
if cancel.is_set():
|
||||
log.info("Cancel requested, breaking stream")
|
||||
break
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Skip StreamFinish — mark_session_completed publishes it.
|
||||
if isinstance(chunk, StreamFinish):
|
||||
continue
|
||||
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error publishing chunk {type(chunk).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Stream loop completed
|
||||
if cancel.is_set():
|
||||
log.info("Stream cancelled by user")
|
||||
|
||||
except BaseException as e:
|
||||
# Handle all exceptions (including CancelledError) with appropriate logging
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
log.info("Turn cancelled")
|
||||
error_msg = "Operation cancelled"
|
||||
else:
|
||||
error_msg = str(e) or type(e).__name__
|
||||
log.error(f"Turn failed: {error_msg}")
|
||||
raise
|
||||
finally:
|
||||
# If no exception but user cancelled, still mark as cancelled
|
||||
if not error_msg and cancel.is_set():
|
||||
error_msg = "Operation cancelled"
|
||||
try:
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=error_msg
|
||||
)
|
||||
except Exception as mark_err:
|
||||
log.error(f"Failed to mark session completed: {mark_err}")
|
||||
224
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
224
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""RabbitMQ queue configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues following the graph executor pattern:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Logging Helper ============ #
|
||||
|
||||
|
||||
class CoPilotLogMetadata(TruncatedLogger):
|
||||
"""Structured logging helper for CoPilot executor.
|
||||
|
||||
In cloud environments (structured logging enabled), uses a simple prefix
|
||||
and passes metadata via json_fields. In local environments, uses a detailed
|
||||
prefix with all metadata key-value pairs for easier debugging.
|
||||
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
max_length: int = 1000,
|
||||
**kwargs: str | None,
|
||||
):
|
||||
# Filter out None values
|
||||
metadata = {k: v for k, v in kwargs.items() if v is not None}
|
||||
metadata["component"] = "CoPilotExecutor"
|
||||
|
||||
if is_structured_logging_enabled():
|
||||
prefix = "[CoPilotExecutor]"
|
||||
else:
|
||||
# Build prefix from metadata key-value pairs
|
||||
meta_parts = "|".join(
|
||||
f"{k}:{v}" for k, v in metadata.items() if k != "component"
|
||||
)
|
||||
prefix = (
|
||||
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
# ============ Exchange and Queue Configuration ============ #
|
||||
|
||||
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||
name="copilot_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||
|
||||
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
name="copilot_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
|
||||
Returns:
|
||||
RabbitMQConfig with exchanges and queues defined
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# Extended consumer timeout for long-running LLM operations
|
||||
# Default 30-minute timeout is insufficient for extended thinking
|
||||
# and agent generation which can take 30+ minutes
|
||||
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
# ============ Message Models ============ #
|
||||
|
||||
|
||||
class CoPilotExecutionEntry(BaseModel):
|
||||
"""Task payload for CoPilot AI generation.
|
||||
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID (also used for dedup/locking)"""
|
||||
|
||||
turn_id: str = ""
|
||||
"""Per-turn UUID for Redis stream isolation"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
is_user_message: bool = True
|
||||
"""Whether the message is from the user (vs system/assistant)"""
|
||||
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
session_id: str
|
||||
"""Session ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_turn(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
message: str,
|
||||
turn_id: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID (also used for dedup/locking)
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
message: User's message to process
|
||||
turn_id: Per-turn UUID for Redis stream isolation
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
|
||||
|
||||
async def enqueue_cancel_task(session_id: str) -> None:
|
||||
"""Publish a cancel request for a running CoPilot session.
|
||||
|
||||
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
|
||||
pods receive the cancellation signal.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
event = CancelCoPilotEvent(session_id=session_id)
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key="", # FANOUT ignores routing key
|
||||
message=event.model_dump_json(),
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from typing import Any, Self, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db_accessors import chat_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
# ===================== Chat data models ===================== #
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -85,6 +55,19 @@ class ChatMessage(BaseModel):
|
||||
tool_calls: list[dict] | None = None
|
||||
function_call: dict | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
|
||||
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
|
||||
return ChatMessage(
|
||||
role=prisma_message.role,
|
||||
content=prisma_message.content,
|
||||
name=prisma_message.name,
|
||||
tool_call_id=prisma_message.toolCallId,
|
||||
refusal=prisma_message.refusal,
|
||||
tool_calls=_parse_json_field(prisma_message.toolCalls),
|
||||
function_call=_parse_json_field(prisma_message.functionCall),
|
||||
)
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
@@ -92,11 +75,10 @@ class Usage(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
class ChatSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
@@ -104,60 +86,9 @@ class ChatSession(BaseModel):
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
Searches backwards for the most recent assistant message (stopping at
|
||||
any user message boundary). If found, appends the tool_call to it.
|
||||
Otherwise creates a new assistant message with the tool_call.
|
||||
"""
|
||||
for msg in reversed(self.messages):
|
||||
if msg.role == "user":
|
||||
break
|
||||
if msg.role == "assistant":
|
||||
if not msg.tool_calls:
|
||||
msg.tool_calls = []
|
||||
msg.tool_calls.append(tool_call)
|
||||
return
|
||||
|
||||
self.messages.append(
|
||||
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
"""Convert Prisma models to Pydantic ChatSession."""
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
"""Convert Prisma ChatSession to Pydantic ChatSession."""
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
@@ -179,11 +110,10 @@ class ChatSession(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
return ChatSession(
|
||||
return cls(
|
||||
session_id=prisma_session.id,
|
||||
user_id=prisma_session.userId,
|
||||
title=prisma_session.title,
|
||||
messages=messages,
|
||||
usage=usage,
|
||||
credentials=credentials,
|
||||
started_at=prisma_session.createdAt,
|
||||
@@ -192,46 +122,55 @@ class ChatSession(BaseModel):
|
||||
successful_agent_schedules=successful_agent_schedules,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_consecutive_assistant_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Merge consecutive assistant messages into single messages.
|
||||
|
||||
Long-running tool flows can create split assistant messages: one with
|
||||
text content and another with tool_calls. Anthropic's API requires
|
||||
tool_result blocks to reference a tool_use in the immediately preceding
|
||||
assistant message, so these splits cause 400 errors via OpenRouter.
|
||||
class ChatSession(ChatSessionInfo):
|
||||
messages: list[ChatMessage]
|
||||
|
||||
@classmethod
|
||||
def new(cls, user_id: str) -> Self:
|
||||
return cls(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
credentials={},
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
|
||||
"""Convert Prisma ChatSession to Pydantic ChatSession."""
|
||||
if prisma_session.Messages is None:
|
||||
raise ValueError(
|
||||
f"Prisma session {prisma_session.id} is missing Messages relation"
|
||||
)
|
||||
|
||||
return cls(
|
||||
**ChatSessionInfo.from_db(prisma_session).model_dump(),
|
||||
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
|
||||
)
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
Searches backwards for the most recent assistant message (stopping at
|
||||
any user message boundary). If found, appends the tool_call to it.
|
||||
Otherwise creates a new assistant message with the tool_call.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
for msg in reversed(self.messages):
|
||||
if msg.role == "user":
|
||||
break
|
||||
if msg.role == "assistant":
|
||||
if not msg.tool_calls:
|
||||
msg.tool_calls = []
|
||||
msg.tool_calls.append(tool_call)
|
||||
return
|
||||
|
||||
result: list[ChatCompletionMessageParam] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev = result[-1]
|
||||
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
||||
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
||||
|
||||
curr_content = curr.get("content") or ""
|
||||
if curr_content:
|
||||
prev_content = prev.get("content") or ""
|
||||
prev["content"] = (
|
||||
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
||||
)
|
||||
|
||||
curr_tool_calls = curr.get("tool_calls")
|
||||
if curr_tool_calls:
|
||||
prev_tool_calls = prev.get("tool_calls")
|
||||
prev["tool_calls"] = (
|
||||
list(prev_tool_calls) + list(curr_tool_calls)
|
||||
if prev_tool_calls
|
||||
else list(curr_tool_calls)
|
||||
)
|
||||
return result
|
||||
self.messages.append(
|
||||
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
|
||||
messages = []
|
||||
@@ -321,39 +260,68 @@ class ChatSession(BaseModel):
|
||||
)
|
||||
return self._merge_consecutive_assistant_messages(messages)
|
||||
|
||||
@staticmethod
|
||||
def _merge_consecutive_assistant_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Merge consecutive assistant messages into single messages.
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
Long-running tool flows can create split assistant messages: one with
|
||||
text content and another with tool_calls. Anthropic's API requires
|
||||
tool_result blocks to reference a tool_use in the immediately preceding
|
||||
assistant message, so these splits cause 400 errors via OpenRouter.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
result: list[ChatCompletionMessageParam] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev = result[-1]
|
||||
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
prev = cast(ChatCompletionAssistantMessageParam, prev)
|
||||
curr = cast(ChatCompletionAssistantMessageParam, msg)
|
||||
|
||||
curr_content = curr.get("content") or ""
|
||||
if curr_content:
|
||||
prev_content = prev.get("content") or ""
|
||||
prev["content"] = (
|
||||
f"{prev_content}\n{curr_content}" if prev_content else curr_content
|
||||
)
|
||||
|
||||
curr_tool_calls = curr.get("tool_calls")
|
||||
if curr_tool_calls:
|
||||
prev_tool_calls = prev.get("tool_calls")
|
||||
prev["tool_calls"] = (
|
||||
list(prev_tool_calls) + list(curr_tool_calls)
|
||||
if prev_tool_calls
|
||||
else list(curr_tool_calls)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# ================ Chat cache + DB operations ================ #
|
||||
|
||||
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
|
||||
# connected directly.
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
"""Cache a chat session in Redis (without persisting to the database)."""
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def invalidate_session_cache(session_id: str) -> None:
|
||||
@@ -371,80 +339,6 @@ async def invalidate_session_cache(session_id: str) -> None:
|
||||
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
if not prisma_session:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession, existing_message_count: int
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database."""
|
||||
# Check if session exists in DB
|
||||
existing = await chat_db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await chat_db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
@@ -476,7 +370,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -492,7 +386,7 @@ async def get_chat_session(
|
||||
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
@@ -500,6 +394,44 @@ async def get_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
if raw_session is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
|
||||
raise RedisError(f"Corrupted session data for {session_id}") from e
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
session = await chat_db().get_chat_session(session_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"Loaded session {session_id} from DB: "
|
||||
f"has_messages={bool(session.messages)}, "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
@@ -519,16 +451,18 @@ async def upsert_chat_session(
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
# Always query DB for existing message count to ensure consistency
|
||||
existing_message_count = await chat_db().get_next_sequence(session.session_id)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
await _save_session_to_db(
|
||||
session,
|
||||
existing_message_count,
|
||||
skip_existence_check=existing_message_count > 0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
@@ -537,7 +471,7 @@ async def upsert_chat_session(
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
@@ -558,6 +492,107 @@ async def upsert_chat_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
session: ChatSession,
|
||||
existing_message_count: int,
|
||||
*,
|
||||
skip_existence_check: bool = False,
|
||||
) -> None:
|
||||
"""Save or update a chat session in the database.
|
||||
|
||||
Args:
|
||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||
and assume the session row already exists. Saves one DB round trip
|
||||
for incremental saves during streaming.
|
||||
"""
|
||||
db = chat_db()
|
||||
|
||||
if not skip_existence_check:
|
||||
# Check if session exists in DB
|
||||
existing = await db.get_chat_session(session.session_id)
|
||||
|
||||
if not existing:
|
||||
# Create new session
|
||||
await db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
existing_message_count = 0
|
||||
|
||||
# Calculate total tokens from usage
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
# Update session metadata
|
||||
await db.update_chat_session(
|
||||
session_id=session.session_id,
|
||||
credentials=session.credentials,
|
||||
successful_agent_runs=session.successful_agent_runs,
|
||||
successful_agent_schedules=session.successful_agent_schedules,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
messages_data.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"refusal": msg.refusal,
|
||||
"tool_calls": msg.tool_calls,
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
Acquires the session lock, re-fetches the latest session state,
|
||||
appends the message, and saves — preventing message loss when
|
||||
concurrent requests modify the same session.
|
||||
"""
|
||||
lock = await _get_session_lock(session_id)
|
||||
|
||||
async with lock:
|
||||
session = await get_chat_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session.messages.append(message)
|
||||
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist message to session {session_id}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
@@ -570,7 +605,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
await chat_db().create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -582,7 +617,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
|
||||
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
@@ -593,20 +628,16 @@ async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
) -> tuple[list[ChatSessionInfo], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
db = chat_db()
|
||||
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await db.get_user_session_count(user_id)
|
||||
|
||||
return sessions, total_count
|
||||
|
||||
@@ -624,7 +655,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
deleted = await chat_db().delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
@@ -659,20 +690,52 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
result = await chat_db().update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await cache_chat_session(cached)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== Chat session locks ==================== #
|
||||
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
This was originally added to solve the specific problem of race conditions between
|
||||
the session title thread and the conversation thread, which always occurs on the
|
||||
same instance as we prevent rapid request sends on the frontend.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks. Explicit cleanup also occurs
|
||||
in `delete_chat_session()`.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
@@ -331,3 +331,96 @@ def test_to_openai_messages_merges_split_assistants():
|
||||
tc_list = merged.get("tool_calls")
|
||||
assert tc_list is not None and len(list(tc_list)) == 1
|
||||
assert list(tc_list)[0]["id"] == "tc1"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Concurrent save collision detection #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
|
||||
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
|
||||
|
||||
Simulates the race condition where:
|
||||
1. Streaming loop starts with saved_msg_count=5
|
||||
2. Long-running callback appends message #5 and saves
|
||||
3. Streaming loop tries to save with stale count=5
|
||||
|
||||
The collision detection should handle this gracefully.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create a session with initial messages
|
||||
session = ChatSession.new(user_id=test_user_id)
|
||||
for i in range(3):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
|
||||
)
|
||||
)
|
||||
|
||||
# Save initial messages
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Simulate streaming loop and callback saving concurrently
|
||||
async def streaming_loop_save():
|
||||
"""Simulates streaming loop saving messages."""
|
||||
# Add 2 messages
|
||||
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content="Streaming message 2")
|
||||
)
|
||||
|
||||
# Wait a bit to let callback potentially save first
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Save (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
async def callback_save():
|
||||
"""Simulates long-running callback saving a message."""
|
||||
# Add 1 message
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
|
||||
)
|
||||
|
||||
# Save immediately (will query DB for existing count)
|
||||
return await upsert_chat_session(session)
|
||||
|
||||
# Run both saves concurrently - one will hit collision detection
|
||||
results = await asyncio.gather(streaming_loop_save(), callback_save())
|
||||
|
||||
# Both should succeed
|
||||
assert all(r is not None for r in results)
|
||||
|
||||
# Reload session from DB to verify
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key) # Clear cache to force DB load
|
||||
|
||||
loaded_session = await get_chat_session(session.session_id, test_user_id)
|
||||
assert loaded_session is not None
|
||||
|
||||
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
|
||||
assert len(loaded_session.messages) == 6
|
||||
|
||||
# Verify no duplicate sequences
|
||||
sequences = []
|
||||
for i, msg in enumerate(loaded_session.messages):
|
||||
# Messages should have sequential sequence numbers starting from 0
|
||||
sequences.append(i)
|
||||
|
||||
# All sequences should be unique and sequential
|
||||
assert sequences == list(range(6))
|
||||
|
||||
# Verify message content is preserved
|
||||
contents = [m.content for m in loaded_session.messages]
|
||||
assert "Message 0" in contents
|
||||
assert "Message 1" in contents
|
||||
assert "Message 2" in contents
|
||||
assert "Streaming message 1" in contents
|
||||
assert "Streaming message 2" in contents
|
||||
assert "Callback result" in contents
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Tests for parallel tool call execution in CoPilot.
|
||||
|
||||
These tests mock _yield_tool_call to avoid importing the full copilot stack
|
||||
which requires Prisma, DB connections, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_run_concurrently():
|
||||
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
n_tools = 3
|
||||
delay_per_tool = 0.2
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"tool_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(n_tools)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
original_yield = None
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
input={},
|
||||
)
|
||||
await asyncio.sleep(delay_per_tool)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"],
|
||||
toolName=tc_list[idx]["function"]["name"],
|
||||
output="{}",
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
original_yield = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
start = time.monotonic()
|
||||
events = []
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
elapsed = time.monotonic() - start
|
||||
finally:
|
||||
svc._yield_tool_call = original_yield
|
||||
|
||||
assert len(events) == n_tools * 2
|
||||
# Parallel: should take ~delay, not ~n*delay
|
||||
assert elapsed < delay_per_tool * (
|
||||
n_tools - 0.5
|
||||
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_works():
|
||||
"""Single tool call should work identically."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": "call_0",
|
||||
"type": "function",
|
||||
"function": {"name": "t", "arguments": "{}"},
|
||||
}
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
|
||||
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = [
|
||||
e
|
||||
async for e in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
)
|
||||
]
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_propagates():
|
||||
"""Retryable errors should be raised after all tools finish."""
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
if idx == 1:
|
||||
raise KeyError("bad")
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
events = []
|
||||
with pytest.raises(KeyError):
|
||||
async for event in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
events.append(event)
|
||||
# First tool's events should still be yielded
|
||||
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_shared_across_parallel_tools():
|
||||
"""All parallel tools should receive the same session instance."""
|
||||
from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
observed_sessions = []
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
observed_sessions.append(sess)
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
|
||||
)
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
async for _ in _execute_tool_calls_parallel(
|
||||
tool_calls, cast(Any, FakeSession())
|
||||
):
|
||||
pass
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
|
||||
assert len(observed_sessions) == 3
|
||||
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_cleans_up():
|
||||
"""Generator close should cancel in-flight tasks."""
|
||||
from backend.copilot.response_model import StreamToolInputAvailable
|
||||
from backend.copilot.service import _execute_tool_calls_parallel
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {"name": f"t_{i}", "arguments": "{}"},
|
||||
}
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
class FakeSession:
|
||||
session_id = "test"
|
||||
user_id = "test"
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def fake_yield(tc_list, idx, sess):
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
|
||||
)
|
||||
started.set()
|
||||
await asyncio.sleep(10) # simulate long-running
|
||||
|
||||
import backend.copilot.service as svc
|
||||
|
||||
orig = svc._yield_tool_call
|
||||
svc._yield_tool_call = fake_yield
|
||||
try:
|
||||
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
|
||||
await gen.__anext__() # get first event
|
||||
await started.wait()
|
||||
await gen.aclose() # close generator
|
||||
finally:
|
||||
svc._yield_tool_call = orig
|
||||
# If we get here without hanging, cleanup worked
|
||||
@@ -5,6 +5,8 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +14,8 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from backend.util.json import dumps as json_dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
@@ -47,7 +51,8 @@ class StreamBaseResponse(BaseModel):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
json_str = self.model_dump_json(exclude_none=True)
|
||||
return f"data: {json_str}\n\n"
|
||||
|
||||
|
||||
# ========== Message Lifecycle ==========
|
||||
@@ -58,15 +63,13 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
sessionId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
description="Session ID for SSE reconnection.",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||
import json
|
||||
|
||||
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
|
||||
data: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"messageId": self.messageId,
|
||||
@@ -163,8 +166,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format, excluding non-spec fields."""
|
||||
import json
|
||||
|
||||
data = {
|
||||
"type": self.type.value,
|
||||
"toolCallId": self.toolCallId,
|
||||
14
autogpt_platform/backend/backend/copilot/sdk/__init__.py
Normal file
14
autogpt_platform/backend/backend/copilot/sdk/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Claude Agent SDK integration for CoPilot.
|
||||
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
57
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
57
autogpt_platform/backend/backend/copilot/sdk/dummy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream dummy chat completion for testing.
|
||||
|
||||
Returns a simple streaming response with text deltas to test:
|
||||
- Streaming infrastructure works
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
172
autogpt_platform/backend/backend/copilot/sdk/otel_setup_test.py
Normal file
172
autogpt_platform/backend/backend/copilot/sdk/otel_setup_test.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Tests for OTEL tracing setup in the SDK copilot path."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestSetupLangfuseOtel:
|
||||
"""Tests for _setup_langfuse_otel()."""
|
||||
|
||||
def test_noop_when_langfuse_not_configured(self):
|
||||
"""No env vars should be set when Langfuse credentials are missing."""
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured", return_value=False
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Clear any previously set env vars
|
||||
env_keys = [
|
||||
"LANGSMITH_OTEL_ENABLED",
|
||||
"LANGSMITH_OTEL_ONLY",
|
||||
"LANGSMITH_TRACING",
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"OTEL_EXPORTER_OTLP_HEADERS",
|
||||
]
|
||||
saved = {k: os.environ.pop(k, None) for k in env_keys}
|
||||
try:
|
||||
_setup_langfuse_otel()
|
||||
for key in env_keys:
|
||||
assert key not in os.environ, f"{key} should not be set"
|
||||
finally:
|
||||
for k, v in saved.items():
|
||||
if v is not None:
|
||||
os.environ[k] = v
|
||||
|
||||
def test_sets_env_vars_when_langfuse_configured(self):
|
||||
"""OTEL env vars should be set when Langfuse credentials exist."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.langfuse_public_key = "pk-test-123"
|
||||
mock_settings.secrets.langfuse_secret_key = "sk-test-456"
|
||||
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
|
||||
mock_settings.secrets.langfuse_tracing_environment = "test"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.configure_claude_agent_sdk",
|
||||
return_value=True,
|
||||
) as mock_configure,
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Clear env vars so setdefault works
|
||||
env_keys = [
|
||||
"LANGSMITH_OTEL_ENABLED",
|
||||
"LANGSMITH_OTEL_ONLY",
|
||||
"LANGSMITH_TRACING",
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"OTEL_EXPORTER_OTLP_HEADERS",
|
||||
"OTEL_RESOURCE_ATTRIBUTES",
|
||||
]
|
||||
saved = {k: os.environ.pop(k, None) for k in env_keys}
|
||||
try:
|
||||
_setup_langfuse_otel()
|
||||
|
||||
assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
|
||||
assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
|
||||
assert os.environ["LANGSMITH_TRACING"] == "true"
|
||||
assert (
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
== "https://langfuse.example.com/api/public/otel"
|
||||
)
|
||||
assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
|
||||
assert (
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"]
|
||||
== "langfuse.environment=test"
|
||||
)
|
||||
|
||||
mock_configure.assert_called_once_with(tags=["sdk"])
|
||||
finally:
|
||||
for k, v in saved.items():
|
||||
if v is not None:
|
||||
os.environ[k] = v
|
||||
elif k in os.environ:
|
||||
del os.environ[k]
|
||||
|
||||
def test_existing_env_vars_not_overwritten(self):
|
||||
"""Explicit env-var overrides should not be clobbered."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.secrets.langfuse_public_key = "pk-test"
|
||||
mock_settings.secrets.langfuse_secret_key = "sk-test"
|
||||
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.configure_claude_agent_sdk",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||
try:
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
|
||||
_setup_langfuse_otel()
|
||||
assert (
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
== "https://custom.endpoint/v1"
|
||||
)
|
||||
finally:
|
||||
if saved is not None:
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
|
||||
elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
|
||||
del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
|
||||
|
||||
def test_graceful_failure_on_exception(self):
|
||||
"""Setup should not raise even if internal code fails."""
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service._is_langfuse_configured",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.Settings",
|
||||
side_effect=RuntimeError("settings unavailable"),
|
||||
),
|
||||
):
|
||||
from backend.copilot.sdk.service import _setup_langfuse_otel
|
||||
|
||||
# Should not raise — just logs and returns
|
||||
_setup_langfuse_otel()
|
||||
|
||||
|
||||
class TestPropagateAttributesImport:
|
||||
"""Verify langfuse.propagate_attributes is available."""
|
||||
|
||||
def test_propagate_attributes_is_importable(self):
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
assert callable(propagate_attributes)
|
||||
|
||||
def test_propagate_attributes_returns_context_manager(self):
|
||||
from langfuse import propagate_attributes
|
||||
|
||||
ctx = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
|
||||
assert hasattr(ctx, "__enter__")
|
||||
assert hasattr(ctx, "__exit__")
|
||||
|
||||
|
||||
class TestReceiveResponseCompat:
|
||||
"""Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""
|
||||
|
||||
def test_receive_response_exists(self):
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "receive_response")
|
||||
|
||||
def test_receive_response_is_async_generator(self):
|
||||
import inspect
|
||||
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
method = getattr(ClaudeSDKClient, "receive_response")
|
||||
assert inspect.isfunction(method) or inspect.ismethod(method)
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Tests for _format_conversation_context and _build_query_message."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import (
|
||||
_build_query_message,
|
||||
_format_conversation_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_conversation_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_format_empty_list():
|
||||
assert _format_conversation_context([]) is None
|
||||
|
||||
|
||||
def test_format_none_content_messages():
|
||||
msgs = [ChatMessage(role="user", content=None)]
|
||||
assert _format_conversation_context(msgs) is None
|
||||
|
||||
|
||||
def test_format_user_message():
|
||||
msgs = [ChatMessage(role="user", content="hello")]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "User: hello" in result
|
||||
assert result.startswith("<conversation_history>")
|
||||
assert result.endswith("</conversation_history>")
|
||||
|
||||
|
||||
def test_format_assistant_text():
|
||||
msgs = [ChatMessage(role="assistant", content="hi there")]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "You responded: hi there" in result
|
||||
|
||||
|
||||
def test_format_assistant_tool_calls():
|
||||
msgs = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[{"function": {"name": "search", "arguments": '{"q": "test"}'}}],
|
||||
)
|
||||
]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert 'You called tool: search({"q": "test"})' in result
|
||||
|
||||
|
||||
def test_format_tool_result():
|
||||
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert 'Tool result: {"result": "ok"}' in result
|
||||
|
||||
|
||||
def test_format_tool_result_none_content():
|
||||
msgs = [ChatMessage(role="tool", content=None)]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "Tool result: " in result
|
||||
|
||||
|
||||
def test_format_full_conversation():
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="find agents"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="I'll search for agents.",
|
||||
tool_calls=[
|
||||
{"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
|
||||
ChatMessage(role="assistant", content="Found Agent1."),
|
||||
]
|
||||
result = _format_conversation_context(msgs)
|
||||
assert result is not None
|
||||
assert "User: find agents" in result
|
||||
assert "You responded: I'll search for agents." in result
|
||||
assert "You called tool: find_agents" in result
|
||||
assert "Tool result:" in result
|
||||
assert "You responded: Found Agent1." in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_query_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(messages: list[ChatMessage]) -> ChatSession:
|
||||
"""Build a minimal ChatSession with the given messages."""
|
||||
now = datetime.now(UTC)
|
||||
return ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=messages,
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_up_to_date():
|
||||
"""With --resume and transcript covers all messages, return raw message."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
ChatMessage(role="user", content="what's new?"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"what's new?",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
)
|
||||
# transcript_msg_count == msg_count - 1, so no gap
|
||||
assert result == "what's new?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_stale_transcript():
|
||||
"""With --resume and stale transcript, gap context is prepended."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="turn 1"),
|
||||
ChatMessage(role="assistant", content="reply 1"),
|
||||
ChatMessage(role="user", content="turn 2"),
|
||||
ChatMessage(role="assistant", content="reply 2"),
|
||||
ChatMessage(role="user", content="turn 3"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"turn 3",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "turn 2" in result
|
||||
assert "reply 2" in result
|
||||
assert "Now, the user says:\nturn 3" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_zero_msg_count():
|
||||
"""With --resume but transcript_msg_count=0, return raw message."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hello"),
|
||||
ChatMessage(role="assistant", content="hi"),
|
||||
ChatMessage(role="user", content="new msg"),
|
||||
]
|
||||
)
|
||||
result = await _build_query_message(
|
||||
"new msg",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "new msg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_single_message():
|
||||
"""Without --resume and only 1 message, return raw message."""
|
||||
session = _make_session([ChatMessage(role="user", content="first")])
|
||||
result = await _build_query_message(
|
||||
"first",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert result == "first"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
"""Without --resume and multiple messages, compress and prepend."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="older question"),
|
||||
ChatMessage(role="assistant", content="older answer"),
|
||||
ChatMessage(role="user", content="new question"),
|
||||
]
|
||||
)
|
||||
|
||||
# Mock _compress_conversation_history to return the messages as-is
|
||||
async def _mock_compress(sess):
|
||||
return sess.messages[:-1]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_conversation_history",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result = await _build_query_message(
|
||||
"new question",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "older question" in result
|
||||
assert "older answer" in result
|
||||
assert "Now, the user says:\nnew question" in result
|
||||
371
autogpt_platform/backend/backend/copilot/sdk/response_adapter.py
Normal file
371
autogpt_platform/backend/backend/copilot/sdk/response_adapter.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This module provides the adapter layer that converts streaming messages from
|
||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||
the frontend expects.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKResponseAdapter:
|
||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This class maintains state during a streaming session to properly track
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None, session_id: str | None = None):
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.session_id = session_id
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.step_open = False
|
||||
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
"""True when there are tool calls that haven't received output yet."""
|
||||
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, sessionId=self.session_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# Flush any SDK built-in tool calls that didn't get a UserMessage
|
||||
# result (e.g. WebSearch, Read handled internally by the CLI).
|
||||
# BUT skip flush when this AssistantMessage is a parallel tool
|
||||
# continuation (contains only ToolUseBlocks) — the prior tools
|
||||
# are still executing concurrently and haven't finished yet.
|
||||
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
|
||||
if not is_tool_only:
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
if not self.step_open:
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
for block in sdk_message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
if block.text:
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
# Strip MCP prefix so frontend sees "find_block"
|
||||
# instead of "mcp__copilot__find_block".
|
||||
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
responses.append(
|
||||
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
||||
)
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=tool_name,
|
||||
input=block.input,
|
||||
)
|
||||
)
|
||||
self.current_tool_calls[block.id] = {"name": tool_name}
|
||||
|
||||
elif isinstance(sdk_message, UserMessage):
|
||||
# UserMessage carries tool results back from tool execution.
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
resolved_in_blocks: set[str] = set()
|
||||
|
||||
sid = (self.session_id or "?")[:12]
|
||||
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
|
||||
logger.info(
|
||||
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
|
||||
"parent_tool_use_id=%s",
|
||||
sid,
|
||||
len(blocks),
|
||||
type(content).__name__,
|
||||
parent_id_preview[:12] if parent_id_preview else "None",
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
# Skip if already resolved (e.g. by flush) — the real
|
||||
# result supersedes the empty flush, but re-emitting
|
||||
# would confuse the frontend's state machine.
|
||||
if block.tool_use_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Prefer the stashed full output over the SDK's
|
||||
# (potentially truncated) ToolResultBlock content.
|
||||
# The SDK truncates large results, writing them to disk,
|
||||
# which breaks frontend widget parsing.
|
||||
output = pop_pending_tool_output(tool_name) or (
|
||||
_extract_tool_output(block.content)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=block.tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(block.tool_use_id)
|
||||
|
||||
# Handle SDK built-in tool results carried via parent_tool_use_id
|
||||
# instead of (or in addition to) ToolResultBlock content.
|
||||
parent_id = sdk_message.parent_tool_use_id
|
||||
if (
|
||||
parent_id
|
||||
and parent_id not in resolved_in_blocks
|
||||
and parent_id not in self.resolved_tool_calls
|
||||
):
|
||||
tool_info = self.current_tool_calls.get(parent_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Try stashed output first (from PostToolUse hook),
|
||||
# then tool_use_result dict, then string content.
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if not output:
|
||||
tur = sdk_message.tool_use_result
|
||||
if tur is not None:
|
||||
output = _extract_tool_use_result(tur)
|
||||
if not output and isinstance(content, str) and content.strip():
|
||||
output = content.strip()
|
||||
|
||||
if output:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=parent_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(parent_id)
|
||||
|
||||
self.resolved_tool_calls.update(resolved_in_blocks)
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
self._end_text_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a text block if needed."""
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current text block if one is open."""
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Emit outputs for tool calls that didn't receive a UserMessage result.
|
||||
|
||||
SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
|
||||
internally without surfacing a separate ``UserMessage`` with
|
||||
``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
"""
|
||||
unresolved = [
|
||||
(tid, info.get("name", "unknown"))
|
||||
for tid, info in self.current_tool_calls.items()
|
||||
if tid not in self.resolved_tool_calls
|
||||
]
|
||||
sid = (self.session_id or "?")[:12]
|
||||
if not unresolved:
|
||||
logger.info(
|
||||
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
|
||||
sid,
|
||||
len(self.current_tool_calls),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
|
||||
sid,
|
||||
len(unresolved),
|
||||
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
|
||||
)
|
||||
|
||||
flushed = False
|
||||
for tool_id, tool_name in unresolved:
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if output is not None:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
len(output),
|
||||
)
|
||||
else:
|
||||
# No output available — emit an empty output so the frontend
|
||||
# transitions the tool from input-available to output-available
|
||||
# (stops the spinner).
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output="",
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.warning(
|
||||
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
|
||||
"(call %s) — stash was empty (likely SDK hook race "
|
||||
"condition: PostToolUse hook hadn't completed before "
|
||||
"flush was triggered)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
"""Extract a string output from a ToolResultBlock's content field."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
if parts:
|
||||
return "".join(parts)
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
|
||||
|
||||
def _extract_tool_use_result(result: object) -> str:
|
||||
"""Extract a string from a UserMessage's ``tool_use_result`` dict.
|
||||
|
||||
SDK built-in tools may store their result in ``tool_use_result``
|
||||
instead of (or in addition to) ``ToolResultBlock`` content blocks.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
# Try common result keys
|
||||
for key in ("content", "text", "output", "stdout", "result"):
|
||||
val = result.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
# Fall back to JSON serialization of the whole dict
|
||||
try:
|
||||
return json.dumps(result)
|
||||
except (TypeError, ValueError):
|
||||
return str(result)
|
||||
if result is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(result)
|
||||
except (TypeError, ValueError):
|
||||
return str(result)
|
||||
@@ -0,0 +1,678 @@
|
||||
"""Unit tests for the SDK response adapter."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
from .tool_adapter import _pending_tool_outputs as _pto
|
||||
from .tool_adapter import _stash_event
|
||||
from .tool_adapter import stash_pending_tool_output as _stash
|
||||
from .tool_adapter import wait_for_stash
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
|
||||
|
||||
|
||||
# -- SystemMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_system_init_emits_start_and_step():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamStart)
|
||||
assert results[0].messageId == "msg-1"
|
||||
assert results[0].sessionId == "session-1"
|
||||
assert isinstance(results[1], StreamStartStep)
|
||||
|
||||
|
||||
def test_system_non_init_emits_nothing():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- AssistantMessage with TextBlock -----------------------------------------
|
||||
|
||||
|
||||
def test_text_block_emits_step_start_and_delta():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "hello"
|
||||
|
||||
|
||||
def test_empty_text_block_emits_only_step():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
# Empty text skipped, but step still opens
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
|
||||
|
||||
def test_multiple_text_deltas_reuse_block_id():
|
||||
adapter = _adapter()
|
||||
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
|
||||
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
|
||||
r1 = adapter.convert_message(msg1)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
# First gets step+start+delta, second only delta (block & step already started)
|
||||
assert len(r1) == 3
|
||||
assert isinstance(r1[0], StreamStartStep)
|
||||
assert isinstance(r1[1], StreamTextStart)
|
||||
assert len(r2) == 1
|
||||
assert isinstance(r2[0], StreamTextDelta)
|
||||
assert r1[1].id == r2[0].id # same block ID
|
||||
|
||||
|
||||
# -- AssistantMessage with ToolUseBlock --------------------------------------
|
||||
|
||||
|
||||
def test_tool_use_emits_input_start_and_available():
|
||||
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="tool-1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"q": "x"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
assert results[1].toolCallId == "tool-1"
|
||||
assert results[1].toolName == "find_agent" # prefix stripped
|
||||
assert isinstance(results[2], StreamToolInputAvailable)
|
||||
assert results[2].toolName == "find_agent" # prefix stripped
|
||||
assert results[2].input == {"q": "x"}
|
||||
|
||||
|
||||
def test_text_then_tool_ends_text_block():
|
||||
adapter = _adapter()
|
||||
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(text_msg) # opens step + text
|
||||
results = adapter.convert_message(tool_msg)
|
||||
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
|
||||
|
||||
# -- UserMessage with ToolResultBlock ----------------------------------------
|
||||
|
||||
|
||||
def test_tool_result_emits_output_and_finish_step():
|
||||
adapter = _adapter()
|
||||
# First register the tool call (opens step) — SDK sends prefixed name
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(tool_msg)
|
||||
|
||||
# Now send tool result
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].toolCallId == "t1"
|
||||
assert results[0].toolName == "find_agent" # prefix stripped
|
||||
assert results[0].output == "found 3 agents"
|
||||
assert results[0].success is True
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_error():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].success is False
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_list_content():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[
|
||||
ToolResultBlock(
|
||||
tool_use_id="t1",
|
||||
content=[
|
||||
{"type": "text", "text": "line1"},
|
||||
{"type": "text", "text": "line2"},
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].output == "line1line2"
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_string_user_message_ignored():
|
||||
"""A plain string UserMessage (not tool results) produces no output."""
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(UserMessage(content="hello"))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- ResultMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_result_success_emits_finish_step_and_finish():
|
||||
adapter = _adapter()
|
||||
# Start some text first (opens step)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="done")], model="test")
|
||||
)
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# TextEnd + FinishStep + StreamFinish
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
assert isinstance(results[2], StreamFinish)
|
||||
|
||||
|
||||
def test_result_error_emits_error_and_finish():
|
||||
adapter = _adapter()
|
||||
msg = ResultMessage(
|
||||
subtype="error",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=True,
|
||||
num_turns=0,
|
||||
session_id="s1",
|
||||
result="API rate limited",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# No step was open, so no FinishStep — just Error + Finish
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamError)
|
||||
assert "API rate limited" in results[0].errorText
|
||||
assert isinstance(results[1], StreamFinish)
|
||||
|
||||
|
||||
# -- Text after tools (new block ID) ----------------------------------------
|
||||
|
||||
|
||||
def test_text_after_tool_gets_new_block_id():
|
||||
adapter = _adapter()
|
||||
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="before")], model="test")
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
# Send tool result (closes step)
|
||||
adapter.convert_message(
|
||||
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
|
||||
)
|
||||
results = adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="after")], model="test")
|
||||
)
|
||||
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "after"
|
||||
|
||||
|
||||
# -- Full conversation flow --------------------------------------------------
|
||||
|
||||
|
||||
def test_full_conversation_flow():
|
||||
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Assistant text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
|
||||
)
|
||||
)
|
||||
# 3. Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="t1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"query": "email"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 4. Tool result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
|
||||
)
|
||||
)
|
||||
)
|
||||
# 5. More text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
|
||||
)
|
||||
)
|
||||
# 6. Result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=500,
|
||||
duration_api_ms=400,
|
||||
is_error=False,
|
||||
num_turns=2,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep", # step 1: text + tool call
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta", # "Let me search"
|
||||
"StreamTextEnd", # closed before tool
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # tool result
|
||||
"StreamFinishStep", # step 1 closed after tool result
|
||||
"StreamStartStep", # step 2: continuation text
|
||||
"StreamTextStart", # new block after tool
|
||||
"StreamTextDelta", # "I found 2"
|
||||
"StreamTextEnd", # closed by result
|
||||
"StreamFinishStep", # step 2 closed
|
||||
"StreamFinish",
|
||||
]
|
||||
|
||||
|
||||
# -- Flush unresolved tool calls --------------------------------------------
|
||||
|
||||
|
||||
def test_flush_unresolved_at_result_message():
|
||||
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in tool — no MCP prefix)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. No UserMessage for this tool — go straight to ResultMessage
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # flushed with empty output
|
||||
"StreamFinishStep", # step closed by flush
|
||||
"StreamFinish",
|
||||
]
|
||||
# The flushed output should be empty (no stash available)
|
||||
output_event = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
][0]
|
||||
assert output_event.toolCallId == "ws-1"
|
||||
assert output_event.toolName == "WebSearch"
|
||||
assert output_event.output == ""
|
||||
|
||||
|
||||
def test_flush_unresolved_at_next_assistant_message():
|
||||
"""Built-in tools get flushed when the next AssistantMessage arrives."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Tool use (built-in — no UserMessage will come)
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 3. Next AssistantMessage triggers flush before processing its blocks
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")], model="test"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep",
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
# Flush at next AssistantMessage:
|
||||
"StreamToolOutputAvailable",
|
||||
"StreamFinishStep", # step closed by flush
|
||||
# New step for continuation text:
|
||||
"StreamStartStep",
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta",
|
||||
]
|
||||
|
||||
|
||||
def test_flush_with_stashed_output():
|
||||
"""Stashed output from PostToolUse hook is used when flushing."""
|
||||
adapter = _adapter()
|
||||
|
||||
# Simulate PostToolUse hook stashing output
|
||||
_pto.set({})
|
||||
_stash("WebSearch", "Search result: 5 items found")
|
||||
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# ResultMessage triggers flush
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
output_events = [
|
||||
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
|
||||
]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].output == "Search result: 5 items found"
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# -- wait_for_stash synchronisation tests --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_signaled():
|
||||
"""wait_for_stash returns True when stash_pending_tool_output signals."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Simulate a PostToolUse hook that stashes output after a short delay
|
||||
async def delayed_stash():
|
||||
await asyncio.sleep(0.01)
|
||||
_stash("WebSearch", "result data")
|
||||
|
||||
asyncio.create_task(delayed_stash())
|
||||
result = await wait_for_stash(timeout=1.0)
|
||||
|
||||
assert result is True
|
||||
assert _pto.get({}).get("WebSearch") == ["result data"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_timeout():
|
||||
"""wait_for_stash returns False on timeout when no stash occurs."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is False
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_stash_already_stashed():
|
||||
"""wait_for_stash picks up a stash that happened just before the wait."""
|
||||
_pto.set({})
|
||||
event = asyncio.Event()
|
||||
_stash_event.set(event)
|
||||
|
||||
# Stash before waiting — simulates hook completing before message arrives
|
||||
_stash("Read", "file contents")
|
||||
# Event is now set; wait_for_stash detects the fast path and returns
|
||||
# immediately without timing out.
|
||||
result = await wait_for_stash(timeout=0.05)
|
||||
assert result is True
|
||||
|
||||
# But the stash itself is populated
|
||||
assert _pto.get({}).get("Read") == ["file contents"]
|
||||
|
||||
# Cleanup
|
||||
_pto.set({}) # type: ignore[arg-type]
|
||||
_stash_event.set(None)
|
||||
|
||||
|
||||
# -- Parallel tool call tests --
|
||||
|
||||
|
||||
def test_parallel_tool_calls_not_flushed_prematurely():
|
||||
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
|
||||
only contains ToolUseBlocks (parallel continuation)."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# First AssistantMessage: tool call #1
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
r1 = adapter.convert_message(msg1)
|
||||
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Second AssistantMessage: tool call #2 (parallel continuation)
|
||||
msg2 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 0, (
|
||||
f"Tool-only AssistantMessage should not flush prior tools, "
|
||||
f"but got {len(output_events)} output events"
|
||||
)
|
||||
|
||||
# Both t1 and t2 should still be unresolved
|
||||
assert "t1" not in adapter.resolved_tool_calls
|
||||
assert "t2" not in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_text_assistant_message_flushes_prior_tools():
|
||||
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
# Init
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call
|
||||
msg1 = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(msg1)
|
||||
assert adapter.has_unresolved_tool_calls
|
||||
|
||||
# Text AssistantMessage (new turn after tools completed)
|
||||
msg2 = AssistantMessage(
|
||||
content=[TextBlock(text="Here are the results")],
|
||||
model="test",
|
||||
)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
|
||||
# Flush SHOULD have happened — t1 gets empty output
|
||||
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].toolCallId == "t1"
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
|
||||
def test_already_resolved_tool_skipped_in_user_message():
|
||||
"""A tool result in UserMessage should be skipped if already resolved by flush."""
|
||||
adapter = SDKResponseAdapter()
|
||||
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
|
||||
# Tool call + flush via text message
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[TextBlock(text="Done")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
assert "t1" in adapter.resolved_tool_calls
|
||||
|
||||
# Now UserMessage arrives with the real result — should be skipped
|
||||
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
|
||||
r = adapter.convert_message(user_msg)
|
||||
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
|
||||
assert (
|
||||
len(output_events) == 0
|
||||
), "Already-resolved tool should not emit duplicate output"
|
||||
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
194
autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
|
||||
|
||||
Instead of pinning to a narrow version range, these tests verify that the
|
||||
installed SDK exposes every class, function, attribute, and method the copilot
|
||||
integration relies on. If an SDK upgrade removes or renames something these
|
||||
tests will catch it immediately.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types & factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sdk_exports_client_and_options():
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
assert inspect.isclass(ClaudeSDKClient)
|
||||
assert inspect.isclass(ClaudeAgentOptions)
|
||||
|
||||
|
||||
def test_sdk_exports_message_types():
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
# Message is a Union type alias, just verify it's importable
|
||||
assert Message is not None
|
||||
|
||||
|
||||
def test_sdk_exports_content_block_types():
|
||||
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
|
||||
|
||||
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
|
||||
|
||||
def test_sdk_exports_mcp_helpers():
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
assert callable(create_sdk_mcp_server)
|
||||
assert callable(tool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeSDKClient interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_has_required_methods():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
required = ["connect", "disconnect", "query", "receive_messages"]
|
||||
for name in required:
|
||||
attr = getattr(ClaudeSDKClient, name, None)
|
||||
assert attr is not None, f"ClaudeSDKClient.{name} missing"
|
||||
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
|
||||
|
||||
|
||||
def test_client_supports_async_context_manager():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "__aenter__")
|
||||
assert hasattr(ClaudeSDKClient, "__aexit__")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeAgentOptions fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_options_accepts_required_fields():
|
||||
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
opts = ClaudeAgentOptions(
|
||||
system_prompt="test",
|
||||
cwd="/tmp",
|
||||
)
|
||||
assert opts.system_prompt == "test"
|
||||
assert opts.cwd == "/tmp"
|
||||
|
||||
|
||||
def test_agent_options_accepts_all_our_fields():
|
||||
"""Comprehensive check of every field we use in service.py."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
fields_we_use = [
|
||||
"system_prompt",
|
||||
"mcp_servers",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"hooks",
|
||||
"cwd",
|
||||
"model",
|
||||
"env",
|
||||
"resume",
|
||||
"max_buffer_size",
|
||||
]
|
||||
sig = inspect.signature(ClaudeAgentOptions)
|
||||
for field in fields_we_use:
|
||||
assert field in sig.parameters, (
|
||||
f"ClaudeAgentOptions no longer accepts '{field}' — "
|
||||
f"available params: {list(sig.parameters.keys())}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_assistant_message_has_content_and_model():
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock
|
||||
|
||||
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
|
||||
assert hasattr(msg, "content")
|
||||
assert hasattr(msg, "model")
|
||||
|
||||
|
||||
def test_result_message_has_required_attrs():
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
assert msg.subtype == "success"
|
||||
assert hasattr(msg, "result")
|
||||
|
||||
|
||||
def test_system_message_has_subtype_and_data():
|
||||
from claude_agent_sdk import SystemMessage
|
||||
|
||||
msg = SystemMessage(subtype="init", data={})
|
||||
assert msg.subtype == "init"
|
||||
assert msg.data == {}
|
||||
|
||||
|
||||
def test_user_message_has_parent_tool_use_id():
|
||||
from claude_agent_sdk import UserMessage
|
||||
|
||||
msg = UserMessage(content="test")
|
||||
assert hasattr(msg, "parent_tool_use_id")
|
||||
assert hasattr(msg, "tool_use_result")
|
||||
|
||||
|
||||
def test_tool_use_block_has_id_name_input():
|
||||
from claude_agent_sdk import ToolUseBlock
|
||||
|
||||
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
|
||||
assert block.id == "t1"
|
||||
assert block.name == "test"
|
||||
assert block.input == {"key": "val"}
|
||||
|
||||
|
||||
def test_tool_result_block_has_required_attrs():
|
||||
from claude_agent_sdk import ToolResultBlock
|
||||
|
||||
block = ToolResultBlock(tool_use_id="t1", content="result")
|
||||
assert block.tool_use_id == "t1"
|
||||
assert block.content == "result"
|
||||
assert hasattr(block, "is_error")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hook_event",
|
||||
["PreToolUse", "PostToolUse", "Stop"],
|
||||
)
|
||||
def test_sdk_exports_hook_event_type(hook_event: str):
|
||||
"""Verify HookEvent literal includes the events our security_hooks use."""
|
||||
from claude_agent_sdk.types import HookEvent
|
||||
|
||||
# HookEvent is a Literal type — check that our events are valid values.
|
||||
# We can't easily inspect Literal at runtime, so just verify the type exists.
|
||||
assert HookEvent is not None
|
||||
371
autogpt_platform/backend/backend/copilot/sdk/security_hooks.py
Normal file
371
autogpt_platform/backend/backend/copilot/sdk/security_hooks.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Security hooks for Claude Agent SDK integration.
|
||||
|
||||
This module provides security hooks that validate tool calls before execution,
|
||||
ensuring multi-user isolation and preventing unauthorized operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from .tool_adapter import (
|
||||
BLOCKED_TOOLS,
|
||||
DANGEROUS_PATTERNS,
|
||||
MCP_TOOL_PREFIX,
|
||||
WORKSPACE_SCOPED_TOOLS,
|
||||
stash_pending_tool_output,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _deny(reason: str) -> dict[str, Any]:
|
||||
"""Return a hook denial response."""
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": reason,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _validate_workspace_path(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Allowed directories:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
||||
if path.startswith("~"):
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
elif not os.path.isabs(path) and sdk_cwd:
|
||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
||||
else:
|
||||
resolved = os.path.realpath(path)
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.realpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
||||
return {}
|
||||
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
f"directory.{workspace_hint} "
|
||||
"This is enforced by the platform and cannot be bypassed."
|
||||
)
|
||||
|
||||
|
||||
def _validate_tool_access(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a tool call is allowed.
|
||||
|
||||
Returns:
|
||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
"Use the CoPilot-specific MCP tools instead."
|
||||
)
|
||||
|
||||
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
||||
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
||||
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
||||
|
||||
# Check for dangerous patterns in tool input
|
||||
# Use json.dumps for predictable format (str() produces Python repr)
|
||||
input_str = json.dumps(tool_input) if tool_input else ""
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return _deny(
|
||||
"[SECURITY] Input contains a blocked pattern. "
|
||||
"This is enforced by the platform and cannot be bypassed."
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_user_isolation(
|
||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
# The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
|
||||
# where a leading "/" is normal. Only check for ".." traversal.
|
||||
# Filesystem paths (source_path, save_to_path) are validated inside
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_security_hooks(
|
||||
user_id: str | None,
|
||||
sdk_cwd: str | None = None,
|
||||
max_subtasks: int = 3,
|
||||
on_stop: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
Includes security validation and observability hooks:
|
||||
- PreToolUse: Security validation before tool execution
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
|
||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
||||
the SDK finishes processing — used to read the JSONL transcript
|
||||
before the CLI process exits.
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
# Per-session tracking for Task sub-agent concurrency.
|
||||
# Set of tool_use_ids that consumed a slot — len() is the active count.
|
||||
task_tool_use_ids: set[str] = set()
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Combined pre-tool-use validation hook."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Rate-limit Task (sub-agent) spawns per session
|
||||
if tool_name == "Task":
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead "
|
||||
"(remove the run_in_background parameter)."
|
||||
),
|
||||
)
|
||||
if len(task_tool_use_ids) >= max_subtasks:
|
||||
logger.warning(
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
f"Maximum {max_subtasks} concurrent sub-tasks. "
|
||||
"Wait for running sub-tasks to finish, "
|
||||
"or continue in the main conversation."
|
||||
),
|
||||
)
|
||||
|
||||
# Strip MCP prefix for consistent validation
|
||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
# Only block non-CoPilot tools; our MCP-registered tools
|
||||
# (including Read for oversized results) are already sandboxed.
|
||||
if not is_copilot_tool:
|
||||
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Validate user isolation
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Reserve the Task slot only after all validations pass
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
"""Release a Task concurrency slot if one was reserved."""
|
||||
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
|
||||
task_tool_use_ids.discard(tool_use_id)
|
||||
logger.info(
|
||||
"[SDK] Task slot released, active=%d/%d, user=%s",
|
||||
len(task_tool_use_ids),
|
||||
max_subtasks,
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def post_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions and stash SDK built-in tool outputs.
|
||||
|
||||
MCP tools stash their output in ``_execute_tool_sync`` before the
|
||||
SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
|
||||
are executed by the CLI internally — this hook captures their
|
||||
output so the response adapter can forward it to the frontend.
|
||||
"""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
logger.info(
|
||||
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
|
||||
tool_name,
|
||||
is_builtin,
|
||||
(tool_use_id or "")[:12],
|
||||
)
|
||||
|
||||
# Stash output for SDK built-in tools so the response adapter can
|
||||
# emit StreamToolOutputAvailable even when the CLI doesn't surface
|
||||
# a separate UserMessage with ToolResultBlock content.
|
||||
if is_builtin:
|
||||
tool_response = input_data.get("tool_response")
|
||||
if tool_response is not None:
|
||||
resp_preview = str(tool_response)[:100]
|
||||
logger.info(
|
||||
"[SDK] Stashing builtin output for %s (%d chars): %s...",
|
||||
tool_name,
|
||||
len(str(tool_response)),
|
||||
resp_preview,
|
||||
)
|
||||
stash_pending_tool_output(tool_name, tool_response)
|
||||
else:
|
||||
logger.warning(
|
||||
"[SDK] PostToolUse for builtin %s but tool_response is None",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def pre_compact_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when SDK triggers context compaction.
|
||||
|
||||
The SDK automatically compacts conversation history when it grows too large.
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
# --- Stop hook: capture transcript path for stateless resume ---
|
||||
async def stop_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Capture transcript path when SDK finishes processing.
|
||||
|
||||
The Stop hook fires while the CLI process is still alive, giving us
|
||||
a reliable window to read the JSONL transcript before SIGTERM.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
||||
|
||||
if transcript_path and on_stop:
|
||||
logger.info(
|
||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
||||
)
|
||||
on_stop(transcript_path, sdk_session_id)
|
||||
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
hooks: dict[str, Any] = {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
"PostToolUseFailure": [
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
|
||||
if on_stop is not None:
|
||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
||||
|
||||
return hooks
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
logger.warning("claude-agent-sdk not available, security hooks disabled")
|
||||
return {}
|
||||
@@ -0,0 +1,418 @@
|
||||
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
|
||||
|
||||
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
|
||||
They validate that the security hooks correctly block unauthorized paths,
|
||||
tool access, and dangerous input patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
from .service import _is_tool_error_or_denial
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
|
||||
def _sdk_available() -> bool:
|
||||
try:
|
||||
import claude_agent_sdk # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
|
||||
|
||||
def _reason(result: dict) -> str:
|
||||
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
|
||||
|
||||
|
||||
# -- Blocked tools -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tools_denied():
|
||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||
result = _validate_tool_access(tool, {})
|
||||
assert _is_denied(result), f"{tool} should be blocked"
|
||||
|
||||
|
||||
def test_unknown_tool_allowed():
|
||||
result = _validate_tool_access("SomeCustomTool", {})
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Workspace-scoped tools --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_write_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_edit_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_glob_within_workspace_allowed():
|
||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_grep_within_workspace_allowed():
|
||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_write_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_traversal_attack_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read",
|
||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_no_path_allowed():
|
||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_no_cwd_denies_absolute():
|
||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Tool-results directory --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
|
||||
def test_bash_builtin_always_blocked():
|
||||
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Dangerous patterns ------------------------------------------------------
|
||||
|
||||
|
||||
def test_dangerous_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_subprocess_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- User isolation ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_workspace_path_traversal_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_allowed():
|
||||
"""Workspace 'path' is a cloud storage key — leading '/' is normal."""
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/ASEAN/report.md"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_non_workspace_tool_passes_isolation():
|
||||
result = _validate_user_isolation(
|
||||
"find_agent", {"query": "email"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Deny message quality ----------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tool_message_clarity():
|
||||
"""Deny messages must include [SECURITY] and 'cannot be bypassed'."""
|
||||
reason = _reason(_validate_tool_access("bash", {}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
|
||||
def test_bash_builtin_blocked_message_clarity():
|
||||
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
|
||||
# -- Task sub-agent hooks (require SDK) --------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _hooks():
|
||||
"""Create security hooks and return (pre, post, post_failure) handlers."""
|
||||
from .security_hooks import create_security_hooks
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
pre = hooks["PreToolUse"][0].hooks[0]
|
||||
post = hooks["PostToolUse"][0].hooks[0]
|
||||
post_failure = hooks["PostToolUseFailure"][0].hooks[0]
|
||||
return pre, post, post_failure
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_background_blocked(_hooks):
|
||||
"""Task with run_in_background=true must be denied."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "foreground" in _reason(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_foreground_allowed(_hooks):
|
||||
"""Task without run_in_background should be allowed."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
|
||||
tool_use_id="tu-1",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_limit_enforced(_hooks):
|
||||
"""Task spawns beyond max_subtasks should be denied."""
|
||||
pre, _, _ = _hooks
|
||||
# First two should pass
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-limit-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied (limit=2)
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
|
||||
tool_use_id="tu-limit-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "Maximum" in _reason(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_slot_released_on_completion(_hooks):
|
||||
"""Completing a Task should free a slot so new Tasks can be spawned."""
|
||||
pre, post, _ = _hooks
|
||||
# Fill both slots
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-comp-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied — at capacity
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
|
||||
tool_use_id="tu-comp-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
# Complete first task — frees a slot
|
||||
await post(
|
||||
{"tool_name": "Task", "tool_input": {}},
|
||||
tool_use_id="tu-comp-0",
|
||||
context={},
|
||||
)
|
||||
|
||||
# Now a new Task should be allowed
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "after release"}},
|
||||
tool_use_id="tu-comp-3",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_slot_released_on_failure(_hooks):
|
||||
"""A failed Task should also free its concurrency slot."""
|
||||
pre, _, post_failure = _hooks
|
||||
# Fill both slots
|
||||
for i in range(2):
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=f"tu-fail-{i}",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# At capacity
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
|
||||
tool_use_id="tu-fail-2",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
# Fail first task — should free a slot
|
||||
await post_failure(
|
||||
{"tool_name": "Task", "tool_input": {}, "error": "something broke"},
|
||||
tool_use_id="tu-fail-0",
|
||||
context={},
|
||||
)
|
||||
|
||||
# New Task should be allowed
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "after failure"}},
|
||||
tool_use_id="tu-fail-3",
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolErrorOrDenial:
|
||||
def test_none_content(self):
|
||||
assert _is_tool_error_or_denial(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _is_tool_error_or_denial("") is False
|
||||
|
||||
def test_benign_output(self):
|
||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
||||
|
||||
def test_security_marker(self):
|
||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
||||
|
||||
def test_cannot_be_bypassed(self):
|
||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
||||
|
||||
def test_not_allowed(self):
|
||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
||||
|
||||
def test_background_task_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subtask_limit_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Maximum 2 concurrent sub-tasks. "
|
||||
"Wait for running sub-tasks to finish, "
|
||||
"or continue in the main conversation."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_denied_marker(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
||||
)
|
||||
|
||||
def test_blocked_marker(self):
|
||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
||||
|
||||
def test_failed_marker(self):
|
||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
||||
|
||||
def test_mcp_iserror(self):
|
||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
||||
|
||||
def test_benign_error_in_value(self):
|
||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
||||
assert _is_tool_error_or_denial("0 errors found") is False
|
||||
|
||||
def test_benign_permission_field(self):
|
||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_benign_not_found_in_listing(self):
|
||||
"""File listing containing 'not found' in filenames should not trigger."""
|
||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
||||
1265
autogpt_platform/backend/backend/copilot/sdk/service.py
Normal file
1265
autogpt_platform/backend/backend/copilot/sdk/service.py
Normal file
File diff suppressed because it is too large
Load Diff
463
autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py
Normal file
463
autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
from backend.copilot.tools.base import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
|
||||
"pending_tool_outputs",
|
||||
default=None, # type: ignore[arg-type]
|
||||
)
|
||||
# Event signaled whenever stash_pending_tool_output() adds a new entry.
|
||||
# Used by the streaming loop to wait for PostToolUse hooks to complete
|
||||
# instead of sleeping an arbitrary duration. The SDK fires hooks via
|
||||
# start_soon (fire-and-forget) so the next message can arrive before
|
||||
# the hook stashes its output — this event bridges that gap.
|
||||
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
|
||||
"_stash_event", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
|
||||
Args:
|
||||
user_id: Current user's ID.
|
||||
session: Current chat session.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_pending_tool_outputs.set({})
|
||||
_stash_event.set(asyncio.Event())
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
)
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
"""Pop and return the oldest stashed output for *tool_name*.
|
||||
|
||||
The SDK CLI may truncate large tool results (writing them to disk and
|
||||
replacing the content with a file reference). This stash keeps the
|
||||
original MCP output so the response adapter can forward it to the
|
||||
frontend for proper widget rendering.
|
||||
|
||||
Uses a FIFO queue per tool name so duplicate calls to the same tool
|
||||
in one turn each get their own output.
|
||||
|
||||
Returns ``None`` if nothing was stashed for *tool_name*.
|
||||
"""
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is None:
|
||||
return None
|
||||
queue = pending.get(tool_name)
|
||||
if not queue:
|
||||
pending.pop(tool_name, None)
|
||||
return None
|
||||
value = queue.pop(0)
|
||||
if not queue:
|
||||
del pending[tool_name]
|
||||
return value
|
||||
|
||||
|
||||
def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
"""Stash tool output for later retrieval by the response adapter.
|
||||
|
||||
Used by the PostToolUse hook to capture SDK built-in tool outputs
|
||||
(WebSearch, Read, etc.) that aren't available through the MCP stash
|
||||
mechanism in ``_execute_tool_sync``.
|
||||
|
||||
Appends to a FIFO queue per tool name so multiple calls to the same
|
||||
tool in one turn are all preserved.
|
||||
"""
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is None:
|
||||
return
|
||||
if isinstance(output, str):
|
||||
text = output
|
||||
else:
|
||||
try:
|
||||
text = json.dumps(output)
|
||||
except (TypeError, ValueError):
|
||||
text = str(output)
|
||||
pending.setdefault(tool_name, []).append(text)
|
||||
# Signal any waiters that new output is available.
|
||||
event = _stash_event.get(None)
|
||||
if event is not None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
the next message (AssistantMessage/ResultMessage) can arrive before the
|
||||
hook completes and stashes its output. This function bridges that gap
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
return False
|
||||
# Fast path: hook already completed before we got here.
|
||||
if event.is_set():
|
||||
event.clear()
|
||||
return True
|
||||
# Slow path: wait for the hook to signal.
|
||||
try:
|
||||
async with asyncio.timeout(timeout):
|
||||
await event.wait()
|
||||
event.clear()
|
||||
return True
|
||||
except TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
async def _execute_tool_sync(
|
||||
base_tool: BaseTool,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=effective_id,
|
||||
**args,
|
||||
)
|
||||
|
||||
text = (
|
||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending.setdefault(base_tool.name, []).append(text)
|
||||
|
||||
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
|
||||
|
||||
# If the tool result contains inline image data, add an MCP image block
|
||||
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
|
||||
image_block = _extract_image_block(text)
|
||||
if image_block:
|
||||
content_blocks.append(image_block)
|
||||
|
||||
return {
|
||||
"content": content_blocks,
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
|
||||
# MIME types that Claude can process as image content blocks.
|
||||
_SUPPORTED_IMAGE_TYPES = frozenset(
|
||||
{"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
)
|
||||
|
||||
|
||||
def _extract_image_block(text: str) -> dict[str, str] | None:
|
||||
"""Extract an MCP image content block from a tool result JSON string.
|
||||
|
||||
Detects workspace file responses with ``content_base64`` and an image
|
||||
MIME type, returning an MCP-format image block that allows Claude to
|
||||
"see" the image. Returns ``None`` if the result is not an inline image.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
mime_type = data.get("mime_type", "")
|
||||
base64_content = data.get("content_base64", "")
|
||||
|
||||
# Only inline small images — large ones would exceed Claude's limits.
|
||||
# 32 KB raw ≈ ~43 KB base64.
|
||||
_MAX_IMAGE_BASE64_BYTES = 43_000
|
||||
if (
|
||||
mime_type in _SUPPORTED_IMAGE_TYPES
|
||||
and base64_content
|
||||
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
|
||||
):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": base64_content,
|
||||
"mimeType": mime_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _mcp_error(message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
user_id, session = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
return _mcp_error("No session context available")
|
||||
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
|
||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
"""Build a JSON Schema input schema for a tool."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": base_tool.parameters.get("properties", {}),
|
||||
"required": base_tool.parameters.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
with open(real_path) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
content = "".join(selected)
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
_READ_TOOL_NAME = "Read"
|
||||
_READ_TOOL_DESCRIPTION = (
|
||||
"Read a file from the local filesystem. "
|
||||
"Use offset and limit to read specific line ranges for large files."
|
||||
)
|
||||
_READ_TOOL_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to read",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (0-indexed). Default: 0",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
}
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
# Security hooks validate that file paths stay within sdk_cwd.
|
||||
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
|
||||
# which provides kernel-level network isolation via unshare --net.
|
||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
||||
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
|
||||
# TodoWrite manages the task checklist shown in the UI — no security concern.
|
||||
_SDK_BUILTIN_TOOLS = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
|
||||
# network isolation (unshare --net) instead.
|
||||
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
|
||||
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
|
||||
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
|
||||
SDK_DISALLOWED_TOOLS = [
|
||||
"Bash",
|
||||
"WebFetch",
|
||||
"AskUserQuestion",
|
||||
]
|
||||
|
||||
# Tools that are blocked entirely in security hooks (defence-in-depth).
|
||||
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
|
||||
BLOCKED_TOOLS = {
|
||||
*SDK_DISALLOWED_TOOLS,
|
||||
"bash",
|
||||
"shell",
|
||||
"exec",
|
||||
"terminal",
|
||||
"command",
|
||||
}
|
||||
|
||||
# Tools allowed only when their path argument stays within the SDK workspace.
|
||||
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
||||
# files, then reads them back) and for workspace file operations.
|
||||
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
||||
|
||||
# Dangerous patterns in tool inputs
|
||||
DANGEROUS_PATTERNS = [
|
||||
r"sudo",
|
||||
r"rm\s+-rf",
|
||||
r"dd\s+if=",
|
||||
r"/etc/passwd",
|
||||
r"/etc/shadow",
|
||||
r"chmod\s+777",
|
||||
r"curl\s+.*\|.*sh",
|
||||
r"wget\s+.*\|.*sh",
|
||||
r"eval\s*\(",
|
||||
r"exec\s*\(",
|
||||
r"__import__",
|
||||
r"os\.system",
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
467
autogpt_platform/backend/backend/copilot/sdk/transcript.py
Normal file
467
autogpt_platform/backend/backend/copilot/sdk/transcript.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""JSONL transcript management for stateless multi-turn resume.
|
||||
|
||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||
next turn we download the transcript, write it to a temp file, and pass
|
||||
``--resume`` so the CLI can reconstruct the full conversation.
|
||||
|
||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
|
||||
|
||||
# Entry types that can be safely removed from the transcript without breaking
|
||||
# the parentUuid conversation tree that ``--resume`` relies on.
|
||||
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
|
||||
# - file-history-snapshot: undo tracking metadata
|
||||
# - queue-operation: internal queue bookkeeping
|
||||
# - summary: session summaries
|
||||
# - pr-link: PR link metadata
|
||||
STRIPPABLE_TYPES = frozenset(
|
||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
"""Result of downloading a transcript with its metadata."""
|
||||
|
||||
content: str
|
||||
message_count: int = 0 # session.messages length when uploaded
|
||||
uploaded_at: float = 0.0 # epoch timestamp of upload
|
||||
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress stripping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def strip_progress_entries(content: str) -> str:
|
||||
"""Remove progress/metadata entries from a JSONL transcript.
|
||||
|
||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
||||
Typically reduces transcript size by ~30%.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
entries: list[dict] = []
|
||||
for line in lines:
|
||||
try:
|
||||
entries.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
# Keep unparseable lines as-is (safety)
|
||||
entries.append({"_raw": line})
|
||||
|
||||
stripped_uuids: set[str] = set()
|
||||
uuid_to_parent: dict[str, str] = {}
|
||||
kept: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
if "_raw" in entry:
|
||||
kept.append(entry)
|
||||
continue
|
||||
uid = entry.get("uuid", "")
|
||||
parent = entry.get("parentUuid", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
||||
if uid:
|
||||
uuid_to_parent[uid] = parent
|
||||
|
||||
if entry_type in STRIPPABLE_TYPES:
|
||||
if uid:
|
||||
stripped_uuids.add(uid)
|
||||
else:
|
||||
kept.append(entry)
|
||||
|
||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
|
||||
result_lines: list[str] = []
|
||||
for entry in kept:
|
||||
if "_raw" in entry:
|
||||
result_lines.append(entry["_raw"])
|
||||
else:
|
||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
|
||||
return "\n".join(result_lines) + "\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_transcript_file(transcript_path: str) -> str | None:
|
||||
"""Read a JSONL transcript file from disk.
|
||||
|
||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
||||
or only contains metadata (≤2 lines with no conversation messages).
|
||||
"""
|
||||
if not transcript_path or not os.path.isfile(transcript_path):
|
||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(transcript_path) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.debug("[Transcript] File is empty: %s", transcript_path)
|
||||
return None
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
# Validate that the transcript has real conversation content
|
||||
# (not just metadata like queue-operation entries).
|
||||
if not validate_transcript(content):
|
||||
logger.debug(
|
||||
"[Transcript] No conversation content (%d lines) in %s",
|
||||
len(lines),
|
||||
transcript_path,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Read {len(lines)} lines, "
|
||||
f"{len(content)} bytes from {transcript_path}"
|
||||
)
|
||||
return content
|
||||
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
"""Sanitize an ID for safe use in file paths.
|
||||
|
||||
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
|
||||
everything else and truncate to *max_len* so the result cannot introduce
|
||||
path separators or other special characters.
|
||||
"""
|
||||
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
|
||||
return cleaned or "unknown"
|
||||
|
||||
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The CLI replaces all non-alphanumeric characters with ``-``.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
|
||||
)
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
|
||||
else:
|
||||
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
cwd: str,
|
||||
) -> str | None:
|
||||
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
|
||||
|
||||
The file lives in the session working directory so it is cleaned up
|
||||
automatically when the session ends.
|
||||
|
||||
Returns the absolute path to the file, or ``None`` on failure.
|
||||
"""
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
os.makedirs(real_cwd, exist_ok=True)
|
||||
safe_id = _sanitize_id(session_id, max_len=8)
|
||||
jsonl_path = os.path.realpath(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def validate_transcript(content: str | None) -> bool:
|
||||
"""Check that a transcript has actual conversation messages.
|
||||
|
||||
A valid transcript for resume needs at least one user message and one
|
||||
assistant message (not just queue-operation / file-history-snapshot
|
||||
metadata).
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
has_user = False
|
||||
has_assistant = False
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg_type = entry.get("type")
|
||||
if msg_type == "user":
|
||||
has_user = True
|
||||
elif msg_type == "assistant":
|
||||
has_assistant = True
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return has_user and has_assistant
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bucket storage (GCS / local via WorkspaceStorageBackend)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
||||
|
||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
||||
"""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.jsonl",
|
||||
)
|
||||
|
||||
|
||||
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects.
|
||||
|
||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
||||
arguments we can reconstruct the same path for download/delete without
|
||||
having stored the return value.
|
||||
"""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
return f"gcs://{backend.bucket_name}/{blob}"
|
||||
else:
|
||||
# LocalWorkspaceStorage returns local://{relative_path}
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
) -> None:
|
||||
"""Strip progress entries and upload transcript to bucket storage.
|
||||
|
||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
||||
what is already stored. Since JSONL is append-only, the latest transcript
|
||||
is always the longest. This prevents a slow/stale background task from
|
||||
clobbering a newer upload from a concurrent turn.
|
||||
|
||||
Args:
|
||||
message_count: ``len(session.messages)`` at upload time — used by
|
||||
the next turn to detect staleness and compress only the gap.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
logger.warning(
|
||||
f"[Transcript] Skipping upload — stripped content not valid "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
new_size = len(encoded)
|
||||
|
||||
# Check existing transcript size to avoid overwriting newer with older
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
try:
|
||||
existing = await storage.retrieve(path)
|
||||
if len(existing) >= new_size:
|
||||
logger.info(
|
||||
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
|
||||
f">= new ({new_size}B) for session {session_id}"
|
||||
)
|
||||
return
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No existing transcript or retrieval error — proceed with upload
|
||||
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Store metadata alongside the transcript so the next turn can detect
|
||||
# staleness and only compress the gap instead of the full history.
|
||||
# Wrapped in try/except so a metadata write failure doesn't orphan
|
||||
# the already-uploaded transcript — the next turn will just fall back
|
||||
# to full gap fill (msg_count=0).
|
||||
try:
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
await storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Uploaded {new_size}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count}) "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
|
||||
|
||||
async def download_transcript(
|
||||
user_id: str, session_id: str
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
try:
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
if isinstance(storage, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
|
||||
meta_path = f"gcs://{storage.bucket_name}/{blob}"
|
||||
else:
|
||||
meta_path = f"local://{mwid}/{mfid}/{mfname}"
|
||||
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"))
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, json.JSONDecodeError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
|
||||
logger.info(
|
||||
f"[Transcript] Downloaded {len(content)}B "
|
||||
f"(msg_count={message_count}) for session {session_id}"
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
uploaded_at=uploaded_at,
|
||||
)
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
||||
255
autogpt_platform/backend/backend/copilot/sdk/transcript_test.py
Normal file
255
autogpt_platform/backend/backend/copilot/sdk/transcript_test.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
read_transcript_file,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
|
||||
|
||||
def _make_jsonl(*entries: dict) -> str:
|
||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
|
||||
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
|
||||
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
||||
ASST_MSG = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "u1",
|
||||
"message": {"role": "assistant", "content": "hello"},
|
||||
}
|
||||
PROGRESS_ENTRY = {
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"parentUuid": "u1",
|
||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
||||
}
|
||||
|
||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
||||
|
||||
|
||||
# --- read_transcript_file ---
|
||||
|
||||
|
||||
class TestReadTranscriptFile:
|
||||
def test_returns_content_for_valid_file(self, tmp_path):
|
||||
path = tmp_path / "session.jsonl"
|
||||
path.write_text(VALID_TRANSCRIPT)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
assert "user" in result
|
||||
|
||||
def test_returns_none_for_missing_file(self):
|
||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
||||
|
||||
def test_returns_none_for_empty_path(self):
|
||||
assert read_transcript_file("") is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.jsonl"
|
||||
path.write_text("")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
path = tmp_path / "meta.jsonl"
|
||||
path.write_text(content)
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
||||
path = tmp_path / "bad.jsonl"
|
||||
path.write_text("not json\n{}\n{}\n")
|
||||
assert read_transcript_file(str(path)) is None
|
||||
|
||||
def test_no_size_limit(self, tmp_path):
|
||||
"""Large files are accepted — bucket storage has no size limit."""
|
||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
||||
path = tmp_path / "big.jsonl"
|
||||
path.write_text(content)
|
||||
result = read_transcript_file(str(path))
|
||||
assert result is not None
|
||||
|
||||
|
||||
# --- write_transcript_to_tempfile ---
|
||||
|
||||
|
||||
class TestWriteTranscriptToTempfile:
|
||||
"""Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""
|
||||
|
||||
def test_writes_file_and_returns_path(self):
|
||||
cwd = "/tmp/copilot-test-write"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(
|
||||
VALID_TRANSCRIPT, "sess-1234-abcd", cwd
|
||||
)
|
||||
assert result is not None
|
||||
assert os.path.isfile(result)
|
||||
assert result.endswith(".jsonl")
|
||||
with open(result) as f:
|
||||
assert f.read() == VALID_TRANSCRIPT
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_creates_parent_directory(self):
|
||||
cwd = "/tmp/copilot-test-mkdir"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
||||
assert result is not None
|
||||
assert os.path.isdir(cwd)
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_uses_session_id_prefix(self):
|
||||
cwd = "/tmp/copilot-test-prefix"
|
||||
try:
|
||||
result = write_transcript_to_tempfile(
|
||||
VALID_TRANSCRIPT, "abcdef12-rest", cwd
|
||||
)
|
||||
assert result is not None
|
||||
assert "abcdef12" in os.path.basename(result)
|
||||
finally:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(cwd, ignore_errors=True)
|
||||
|
||||
def test_rejects_cwd_outside_sandbox(self, tmp_path):
|
||||
cwd = str(tmp_path / "not-copilot")
|
||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- validate_transcript ---
|
||||
|
||||
|
||||
class TestValidateTranscript:
|
||||
def test_valid_transcript(self):
|
||||
assert validate_transcript(VALID_TRANSCRIPT) is True
|
||||
|
||||
def test_none_content(self):
|
||||
assert validate_transcript(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert validate_transcript("") is False
|
||||
|
||||
def test_metadata_only(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_user_only_no_assistant(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_assistant_only_no_user(self):
|
||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
||||
assert validate_transcript(content) is False
|
||||
|
||||
def test_invalid_json_returns_false(self):
|
||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
||||
|
||||
|
||||
# --- strip_progress_entries ---
|
||||
|
||||
|
||||
class TestStripProgressEntries:
|
||||
def test_strips_all_strippable_types(self):
|
||||
"""All STRIPPABLE_TYPES are removed from the output."""
|
||||
entries = [
|
||||
USER_MSG,
|
||||
{"type": "progress", "uuid": "p1", "parentUuid": "u1"},
|
||||
{"type": "file-history-snapshot", "files": []},
|
||||
{"type": "queue-operation", "subtype": "create"},
|
||||
{"type": "summary", "text": "..."},
|
||||
{"type": "pr-link", "url": "..."},
|
||||
ASST_MSG,
|
||||
]
|
||||
result = strip_progress_entries(_make_jsonl(*entries))
|
||||
result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
|
||||
assert result_types == {"user", "assistant"}
|
||||
for stype in STRIPPABLE_TYPES:
|
||||
assert stype not in result_types
|
||||
|
||||
def test_reparents_children_of_stripped_entries(self):
|
||||
"""An assistant message whose parent is a progress entry gets reparented."""
|
||||
progress = {
|
||||
"type": "progress",
|
||||
"uuid": "p1",
|
||||
"parentUuid": "u1",
|
||||
"data": {"type": "bash_progress"},
|
||||
}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p1", # Points to progress
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, progress, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
|
||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
||||
# Should be reparented to u1 (the user message)
|
||||
assert asst_entry["parentUuid"] == "u1"
|
||||
|
||||
def test_reparents_through_chain(self):
|
||||
"""Reparenting walks through multiple stripped entries."""
|
||||
p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
||||
p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
|
||||
p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
|
||||
asst = {
|
||||
"type": "assistant",
|
||||
"uuid": "a1",
|
||||
"parentUuid": "p3", # 3 levels deep
|
||||
"message": {"role": "assistant", "content": "done"},
|
||||
}
|
||||
content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
|
||||
result = strip_progress_entries(content)
|
||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
||||
|
||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
||||
assert asst_entry["parentUuid"] == "u1"
|
||||
|
||||
def test_preserves_non_strippable_entries(self):
|
||||
"""User, assistant, and system entries are preserved."""
|
||||
system = {"type": "system", "uuid": "s1", "message": "prompt"}
|
||||
content = _make_jsonl(system, USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
||||
assert result_types == ["system", "user", "assistant"]
|
||||
|
||||
def test_empty_input(self):
|
||||
result = strip_progress_entries("")
|
||||
# Should return just a newline (empty content stripped)
|
||||
assert result.strip() == ""
|
||||
|
||||
def test_no_strippable_entries(self):
|
||||
"""When there's nothing to strip, output matches input structure."""
|
||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_lines = result.strip().split("\n")
|
||||
assert len(result_lines) == 2
|
||||
|
||||
def test_handles_entries_without_uuid(self):
|
||||
"""Entries without uuid field are handled gracefully."""
|
||||
no_uuid = {"type": "queue-operation", "subtype": "create"}
|
||||
content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
|
||||
result = strip_progress_entries(content)
|
||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
||||
# queue-operation is strippable
|
||||
assert "queue-operation" not in result_types
|
||||
assert "user" in result_types
|
||||
assert "assistant" in result_types
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user