mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 08:14:58 -05:00
Compare commits
8 Commits
feat/copil
...
ntindle/go
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7705427bb | ||
|
|
201ec5aa3a | ||
|
|
2b8134a711 | ||
|
|
90b3b5ba16 | ||
|
|
f4f81bc4fc | ||
|
|
c5abc01f25 | ||
|
|
8b7053c1de | ||
|
|
e00c1202ad |
@@ -5,13 +5,42 @@
|
|||||||
!docs/
|
!docs/
|
||||||
|
|
||||||
# Platform - Libs
|
# Platform - Libs
|
||||||
!autogpt_platform/autogpt_libs/
|
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||||
|
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||||
|
!autogpt_platform/autogpt_libs/poetry.lock
|
||||||
|
!autogpt_platform/autogpt_libs/README.md
|
||||||
|
|
||||||
# Platform - Backend
|
# Platform - Backend
|
||||||
!autogpt_platform/backend/
|
!autogpt_platform/backend/backend/
|
||||||
|
!autogpt_platform/backend/test/e2e_test_data.py
|
||||||
|
!autogpt_platform/backend/migrations/
|
||||||
|
!autogpt_platform/backend/schema.prisma
|
||||||
|
!autogpt_platform/backend/pyproject.toml
|
||||||
|
!autogpt_platform/backend/poetry.lock
|
||||||
|
!autogpt_platform/backend/README.md
|
||||||
|
!autogpt_platform/backend/.env
|
||||||
|
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||||
|
|
||||||
|
# Platform - Market
|
||||||
|
!autogpt_platform/market/market/
|
||||||
|
!autogpt_platform/market/scripts.py
|
||||||
|
!autogpt_platform/market/schema.prisma
|
||||||
|
!autogpt_platform/market/pyproject.toml
|
||||||
|
!autogpt_platform/market/poetry.lock
|
||||||
|
!autogpt_platform/market/README.md
|
||||||
|
|
||||||
# Platform - Frontend
|
# Platform - Frontend
|
||||||
!autogpt_platform/frontend/
|
!autogpt_platform/frontend/src/
|
||||||
|
!autogpt_platform/frontend/public/
|
||||||
|
!autogpt_platform/frontend/scripts/
|
||||||
|
!autogpt_platform/frontend/package.json
|
||||||
|
!autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
|
!autogpt_platform/frontend/tsconfig.json
|
||||||
|
!autogpt_platform/frontend/README.md
|
||||||
|
## config
|
||||||
|
!autogpt_platform/frontend/*.config.*
|
||||||
|
!autogpt_platform/frontend/.env.*
|
||||||
|
!autogpt_platform/frontend/.env
|
||||||
|
|
||||||
# Classic - AutoGPT
|
# Classic - AutoGPT
|
||||||
!classic/original_autogpt/autogpt/
|
!classic/original_autogpt/autogpt/
|
||||||
@@ -35,38 +64,6 @@
|
|||||||
# Classic - Frontend
|
# Classic - Frontend
|
||||||
!classic/frontend/build/web/
|
!classic/frontend/build/web/
|
||||||
|
|
||||||
# Explicitly re-ignore unwanted files from whitelisted directories
|
# Explicitly re-ignore some folders
|
||||||
# Note: These patterns MUST come after the whitelist rules to take effect
|
.*
|
||||||
|
**/__pycache__
|
||||||
# Hidden files and directories (but keep frontend .env files needed for build)
|
|
||||||
**/.*
|
|
||||||
!autogpt_platform/frontend/.env
|
|
||||||
!autogpt_platform/frontend/.env.default
|
|
||||||
!autogpt_platform/frontend/.env.production
|
|
||||||
|
|
||||||
# Python artifacts
|
|
||||||
**/__pycache__/
|
|
||||||
**/*.pyc
|
|
||||||
**/*.pyo
|
|
||||||
**/.venv/
|
|
||||||
**/.ruff_cache/
|
|
||||||
**/.pytest_cache/
|
|
||||||
**/.coverage
|
|
||||||
**/htmlcov/
|
|
||||||
|
|
||||||
# Node artifacts
|
|
||||||
**/node_modules/
|
|
||||||
**/.next/
|
|
||||||
**/storybook-static/
|
|
||||||
**/playwright-report/
|
|
||||||
**/test-results/
|
|
||||||
|
|
||||||
# Build artifacts
|
|
||||||
**/dist/
|
|
||||||
**/build/
|
|
||||||
!autogpt_platform/frontend/src/**/build/
|
|
||||||
**/target/
|
|
||||||
|
|
||||||
# Logs and temp files
|
|
||||||
**/*.log
|
|
||||||
**/*.tmp
|
|
||||||
|
|||||||
249
.github/workflows/platform-frontend-ci.yml
vendored
249
.github/workflows/platform-frontend-ci.yml
vendored
@@ -26,6 +26,7 @@ jobs:
|
|||||||
setup:
|
setup:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
|
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
components-changed: ${{ steps.filter.outputs.components }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -40,17 +41,28 @@ jobs:
|
|||||||
components:
|
components:
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
- 'autogpt_platform/frontend/src/components/**'
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Set up Node.js
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Install dependencies to populate cache
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Generate cache key
|
||||||
|
id: cache-key
|
||||||
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Cache dependencies
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.pnpm-store
|
||||||
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
@@ -61,15 +73,22 @@ jobs:
|
|||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Set up Node.js
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Restore dependencies cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.pnpm-store
|
||||||
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -92,15 +111,22 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Set up Node.js
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Restore dependencies cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.pnpm-store
|
||||||
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -115,8 +141,10 @@ jobs:
|
|||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
e2e_test:
|
e2e_test:
|
||||||
name: end-to-end tests
|
|
||||||
runs-on: big-boi
|
runs-on: big-boi
|
||||||
|
needs: setup
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -124,11 +152,19 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Platform - Copy default supabase .env
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v6
|
||||||
|
with:
|
||||||
|
node-version: "22.18.0"
|
||||||
|
|
||||||
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Copy default supabase .env
|
||||||
run: |
|
run: |
|
||||||
cp ../.env.default ../.env
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
- name: Copy backend .env and set OpenAI API key
|
||||||
run: |
|
run: |
|
||||||
cp ../backend/.env.default ../backend/.env
|
cp ../backend/.env.default ../backend/.env
|
||||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||||
@@ -136,125 +172,77 @@ jobs:
|
|||||||
# Used by E2E test data script to generate embeddings for approved store agents
|
# Used by E2E test data script to generate embeddings for approved store agents
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
- name: Set up Platform - Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
with:
|
|
||||||
driver: docker-container
|
|
||||||
driver-opts: network=host
|
|
||||||
|
|
||||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
- name: Cache Docker layers
|
||||||
uses: crazy-max/ghaction-github-runtime@v3
|
|
||||||
|
|
||||||
- name: Set up Platform - Build Docker images (with cache)
|
|
||||||
working-directory: autogpt_platform
|
|
||||||
run: |
|
|
||||||
pip install pyyaml
|
|
||||||
|
|
||||||
# Resolve extends and generate a flat compose file that bake can understand
|
|
||||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
|
||||||
|
|
||||||
# Add cache configuration to the resolved compose file
|
|
||||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
|
||||||
--source docker-compose.resolved.yml \
|
|
||||||
--cache-from "type=gha" \
|
|
||||||
--cache-to "type=gha,mode=max" \
|
|
||||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
|
||||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
|
||||||
--git-ref "${{ github.ref }}"
|
|
||||||
|
|
||||||
# Build with bake using the resolved compose file (now includes cache config)
|
|
||||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
|
||||||
env:
|
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
|
||||||
|
|
||||||
- name: Set up tests - Cache E2E test data
|
|
||||||
id: e2e-data-cache
|
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: /tmp/e2e_test_data.sql
|
path: /tmp/.buildx-cache
|
||||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-buildx-frontend-test-
|
||||||
|
|
||||||
- name: Set up Platform - Start Supabase DB + Auth
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||||
echo "Waiting for database to be ready..."
|
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
|
||||||
echo "Waiting for auth service to be ready..."
|
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
|
||||||
|
|
||||||
- name: Set up Platform - Run migrations
|
|
||||||
run: |
|
|
||||||
echo "Running migrations..."
|
|
||||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
|
||||||
echo "✅ Migrations completed"
|
|
||||||
env:
|
env:
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
DOCKER_BUILDKIT: 1
|
||||||
|
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||||
|
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||||
|
|
||||||
- name: Set up tests - Load cached E2E test data
|
- name: Move cache
|
||||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
|
||||||
run: |
|
run: |
|
||||||
echo "✅ Found cached E2E test data, restoring..."
|
rm -rf /tmp/.buildx-cache
|
||||||
{
|
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||||
echo "SET session_replication_role = 'replica';"
|
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||||
cat /tmp/e2e_test_data.sql
|
fi
|
||||||
echo "SET session_replication_role = 'origin';"
|
|
||||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
|
||||||
# Refresh materialized views after restore
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
|
||||||
|
|
||||||
echo "✅ E2E test data restored from cache"
|
- name: Wait for services to be ready
|
||||||
|
|
||||||
- name: Set up Platform - Start (all other services)
|
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
|
||||||
echo "Waiting for rest_server to be ready..."
|
echo "Waiting for rest_server to be ready..."
|
||||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||||
env:
|
echo "Waiting for database to be ready..."
|
||||||
NEXT_PUBLIC_PW_TEST: true
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||||
|
|
||||||
- name: Set up tests - Create E2E test data
|
- name: Create E2E test data
|
||||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
run: |
|
||||||
echo "Creating E2E test data..."
|
echo "Creating E2E test data..."
|
||||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
# First try to run the script from inside the container
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||||
echo "❌ E2E test data creation failed!"
|
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||||
exit 1
|
echo "❌ E2E test data creation failed!"
|
||||||
}
|
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
else
|
||||||
|
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||||
|
# Copy the script into the container and run it
|
||||||
|
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||||
|
echo "❌ Failed to copy script to container"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||||
|
echo "❌ E2E test data creation failed!"
|
||||||
|
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
fi
|
||||||
|
|
||||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
- name: Restore dependencies cache
|
||||||
echo "Dumping database for cache..."
|
uses: actions/cache@v5
|
||||||
{
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
pg_dump -U postgres --data-only --column-inserts \
|
|
||||||
--table='auth.users' postgres
|
|
||||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
|
||||||
pg_dump -U postgres --data-only --column-inserts \
|
|
||||||
--schema=platform \
|
|
||||||
--exclude-table='platform._prisma_migrations' \
|
|
||||||
--exclude-table='platform.apscheduler_jobs' \
|
|
||||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
|
||||||
postgres
|
|
||||||
} > /tmp/e2e_test_data.sql
|
|
||||||
|
|
||||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
|
||||||
|
|
||||||
- name: Set up tests - Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up tests - Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
path: ~/.pnpm-store
|
||||||
cache: "pnpm"
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
- name: Set up tests - Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Set up tests - Install browser 'chromium'
|
- name: Install Browser 'chromium'
|
||||||
run: pnpm playwright install --with-deps chromium
|
run: pnpm playwright install --with-deps chromium
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
@@ -281,7 +269,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Print Final Docker Compose logs
|
- name: Print Final Docker Compose logs
|
||||||
if: always()
|
if: always()
|
||||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
run: docker compose -f ../docker-compose.yml logs
|
||||||
|
|
||||||
integration_test:
|
integration_test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -293,15 +281,22 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Set up Node.js
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set up Node
|
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Restore dependencies cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.pnpm-store
|
||||||
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|||||||
@@ -1,195 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Add cache configuration to a resolved docker-compose file for all services
|
|
||||||
that have a build key, and ensure image names match what docker compose expects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_BRANCH = "dev"
|
|
||||||
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Add cache config to a resolved compose file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--source",
|
|
||||||
required=True,
|
|
||||||
help="Source compose file to read (should be output of `docker compose config`)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cache-from",
|
|
||||||
default="type=gha",
|
|
||||||
help="Cache source configuration",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cache-to",
|
|
||||||
default="type=gha,mode=max",
|
|
||||||
help="Cache destination configuration",
|
|
||||||
)
|
|
||||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
|
||||||
parser.add_argument(
|
|
||||||
f"--{component}-hash",
|
|
||||||
default="",
|
|
||||||
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--git-ref",
|
|
||||||
default="",
|
|
||||||
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
|
|
||||||
git_ref_scope = ""
|
|
||||||
if args.git_ref:
|
|
||||||
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
|
|
||||||
|
|
||||||
with open(args.source, "r") as f:
|
|
||||||
compose = yaml.safe_load(f)
|
|
||||||
|
|
||||||
# Get project name from compose file or default
|
|
||||||
project_name = compose.get("name", "autogpt_platform")
|
|
||||||
|
|
||||||
def get_image_name(dockerfile: str, target: str) -> str:
|
|
||||||
"""Generate image name based on Dockerfile folder and build target."""
|
|
||||||
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
|
|
||||||
if len(dockerfile_parts) >= 2:
|
|
||||||
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
|
|
||||||
else:
|
|
||||||
folder_name = "app"
|
|
||||||
return f"{project_name}-{folder_name}:{target}"
|
|
||||||
|
|
||||||
def get_build_key(dockerfile: str, target: str) -> str:
|
|
||||||
"""Generate a unique key for a Dockerfile+target combination."""
|
|
||||||
return f"{dockerfile}:{target}"
|
|
||||||
|
|
||||||
def get_component(dockerfile: str) -> str | None:
|
|
||||||
"""Get component name (frontend/backend) from dockerfile path."""
|
|
||||||
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
|
||||||
if component in dockerfile:
|
|
||||||
return component
|
|
||||||
return None
|
|
||||||
|
|
||||||
# First pass: collect all services with build configs and identify duplicates
|
|
||||||
# Track which (dockerfile, target) combinations we've seen
|
|
||||||
build_key_to_first_service: dict[str, str] = {}
|
|
||||||
services_to_build: list[str] = []
|
|
||||||
services_to_dedupe: list[str] = []
|
|
||||||
|
|
||||||
for service_name, service_config in compose.get("services", {}).items():
|
|
||||||
if "build" not in service_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
build_config = service_config["build"]
|
|
||||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
|
||||||
target = build_config.get("target", "default")
|
|
||||||
build_key = get_build_key(dockerfile, target)
|
|
||||||
|
|
||||||
if build_key not in build_key_to_first_service:
|
|
||||||
# First service with this build config - it will do the actual build
|
|
||||||
build_key_to_first_service[build_key] = service_name
|
|
||||||
services_to_build.append(service_name)
|
|
||||||
else:
|
|
||||||
# Duplicate - will just use the image from the first service
|
|
||||||
services_to_dedupe.append(service_name)
|
|
||||||
|
|
||||||
# Second pass: configure builds and deduplicate
|
|
||||||
modified_services = []
|
|
||||||
for service_name, service_config in compose.get("services", {}).items():
|
|
||||||
if "build" not in service_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
build_config = service_config["build"]
|
|
||||||
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
|
||||||
target = build_config.get("target", "latest")
|
|
||||||
image_name = get_image_name(dockerfile, target)
|
|
||||||
|
|
||||||
# Set image name for all services (needed for both builders and deduped)
|
|
||||||
service_config["image"] = image_name
|
|
||||||
|
|
||||||
if service_name in services_to_dedupe:
|
|
||||||
# Remove build config - this service will use the pre-built image
|
|
||||||
del service_config["build"]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# This service will do the actual build - add cache config
|
|
||||||
cache_from_list = []
|
|
||||||
cache_to_list = []
|
|
||||||
|
|
||||||
component = get_component(dockerfile)
|
|
||||||
if not component:
|
|
||||||
# Skip services that don't clearly match frontend/backend
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get the hash for this component
|
|
||||||
component_hash = getattr(args, f"{component}_hash")
|
|
||||||
|
|
||||||
# Scope format: platform-{component}-{target}-{hash|ref}
|
|
||||||
# Example: platform-backend-server-abc123
|
|
||||||
|
|
||||||
if "type=gha" in args.cache_from:
|
|
||||||
# 1. Primary: exact hash match (most specific)
|
|
||||||
if component_hash:
|
|
||||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
|
||||||
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
|
|
||||||
|
|
||||||
# 2. Fallback: branch-based cache
|
|
||||||
if git_ref_scope:
|
|
||||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
|
||||||
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
|
|
||||||
|
|
||||||
# 3. Fallback: dev branch cache (for PRs/feature branches)
|
|
||||||
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
|
|
||||||
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
|
|
||||||
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
|
|
||||||
|
|
||||||
if "type=gha" in args.cache_to:
|
|
||||||
# Write to both hash-based and branch-based scopes
|
|
||||||
if component_hash:
|
|
||||||
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
|
||||||
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
|
|
||||||
|
|
||||||
if git_ref_scope:
|
|
||||||
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
|
||||||
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
|
|
||||||
|
|
||||||
# Ensure we have at least one cache source/target
|
|
||||||
if not cache_from_list:
|
|
||||||
cache_from_list.append(args.cache_from)
|
|
||||||
if not cache_to_list:
|
|
||||||
cache_to_list.append(args.cache_to)
|
|
||||||
|
|
||||||
build_config["cache_from"] = cache_from_list
|
|
||||||
build_config["cache_to"] = cache_to_list
|
|
||||||
modified_services.append(service_name)
|
|
||||||
|
|
||||||
# Write back to the same file
|
|
||||||
with open(args.source, "w") as f:
|
|
||||||
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
|
|
||||||
|
|
||||||
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
|
|
||||||
for svc in modified_services:
|
|
||||||
svc_config = compose["services"][svc]
|
|
||||||
build_cfg = svc_config.get("build", {})
|
|
||||||
cache_from_list = build_cfg.get("cache_from", ["none"])
|
|
||||||
cache_to_list = build_cfg.get("cache_to", ["none"])
|
|
||||||
print(f" - {svc}")
|
|
||||||
print(f" image: {svc_config.get('image', 'N/A')}")
|
|
||||||
print(f" cache_from: {cache_from_list}")
|
|
||||||
print(f" cache_to: {cache_to_list}")
|
|
||||||
if services_to_dedupe:
|
|
||||||
print(
|
|
||||||
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
|
|
||||||
)
|
|
||||||
for svc in services_to_dedupe:
|
|
||||||
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -45,11 +45,6 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
### Branching Strategy
|
|
||||||
|
|
||||||
- **`dev`** is the main development branch. All PRs should target `dev`.
|
|
||||||
- **`master`** is the production branch. Only used for production releases.
|
|
||||||
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR against the `dev` branch of the repository.
|
- Create the PR against the `dev` branch of the repository.
|
||||||
|
|||||||
169
autogpt_platform/autogpt_libs/poetry.lock
generated
169
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cryptography"
|
name = "cryptography"
|
||||||
version = "46.0.5"
|
version = "46.0.4"
|
||||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
|
{file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
|
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
|
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
|
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
|
{file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
|
||||||
{file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
|
{file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
|
{file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
|
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
|
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
|
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
|
{file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
|
||||||
{file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
|
{file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
|
{file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
|
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
|
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
|
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
|
{file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
|
||||||
{file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
|
{file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
|
||||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
|
{file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
|
||||||
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
{file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
|
|||||||
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
||||||
sdist = ["build (>=1.0.0)"]
|
sdist = ["build (>=1.0.0)"]
|
||||||
ssh = ["bcrypt (>=3.1.5)"]
|
ssh = ["bcrypt (>=3.1.5)"]
|
||||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||||
test-randomorder = ["pytest-randomly"]
|
test-randomorder = ["pytest-randomly"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -570,25 +570,24 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.128.7"
|
version = "0.128.0"
|
||||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
|
{file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"},
|
||||||
{file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
|
{file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
annotated-doc = ">=0.0.2"
|
annotated-doc = ">=0.0.2"
|
||||||
pydantic = ">=2.7.0"
|
pydantic = ">=2.7.0"
|
||||||
starlette = ">=0.40.0,<1.0.0"
|
starlette = ">=0.40.0,<0.51.0"
|
||||||
typing-extensions = ">=4.8.0"
|
typing-extensions = ">=4.8.0"
|
||||||
typing-inspection = ">=0.4.2"
|
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
|
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
|
||||||
@@ -1063,14 +1062,14 @@ urllib3 = ">=1.26.0,<3"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "launchdarkly-server-sdk"
|
name = "launchdarkly-server-sdk"
|
||||||
version = "9.15.0"
|
version = "9.14.1"
|
||||||
description = "LaunchDarkly SDK for Python"
|
description = "LaunchDarkly SDK for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
|
{file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
|
||||||
{file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
|
{file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -1479,14 +1478,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "postgrest"
|
name = "postgrest"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
|
{file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
|
||||||
{file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
|
{file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2249,14 +2248,14 @@ cli = ["click (>=5.0)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "realtime"
|
name = "realtime"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = ""
|
description = ""
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
|
{file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
|
||||||
{file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
|
{file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2437,14 +2436,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "storage3"
|
name = "storage3"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = "Supabase Storage client for Python."
|
description = "Supabase Storage client for Python."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
|
{file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
|
||||||
{file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
|
{file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2488,35 +2487,35 @@ python-dateutil = ">=2.6.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase"
|
name = "supabase"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = "Supabase client for Python."
|
description = "Supabase client for Python."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
|
{file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
|
||||||
{file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
|
{file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
httpx = ">=0.26,<0.29"
|
httpx = ">=0.26,<0.29"
|
||||||
postgrest = "2.28.0"
|
postgrest = "2.27.2"
|
||||||
realtime = "2.28.0"
|
realtime = "2.27.2"
|
||||||
storage3 = "2.28.0"
|
storage3 = "2.27.2"
|
||||||
supabase-auth = "2.28.0"
|
supabase-auth = "2.27.2"
|
||||||
supabase-functions = "2.28.0"
|
supabase-functions = "2.27.2"
|
||||||
yarl = ">=1.22.0"
|
yarl = ">=1.22.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase-auth"
|
name = "supabase-auth"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = "Python Client Library for Supabase Auth"
|
description = "Python Client Library for Supabase Auth"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
|
{file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
|
||||||
{file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
|
{file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2526,14 +2525,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase-functions"
|
name = "supabase-functions"
|
||||||
version = "2.28.0"
|
version = "2.27.2"
|
||||||
description = "Library for Supabase Functions"
|
description = "Library for Supabase Functions"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
|
{file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
|
||||||
{file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
|
{file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2912,4 +2911,4 @@ type = ["pytest-mypy"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<4.0"
|
python-versions = ">=3.10,<4.0"
|
||||||
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
expiringdict = "^1.2.2"
-fastapi = "^0.128.7"
+fastapi = "^0.128.0"
google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.15.0"
+launchdarkly-server-sdk = "^9.14.1"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
redis = "^6.2.0"
-supabase = "^2.28.0"
+supabase = "^2.27.2"
uvicorn = "^0.40.0"

[tool.poetry.group.dev.dependencies]
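Side note on the caret pins above: Poetry's "^2.27.2" is shorthand for ">=2.27.2,<3.0.0", so both the old and new supabase pins fall inside the same major range. A quick illustrative check with the packaging library (not part of this diff) is a minimal sketch only:

import importlib.metadata

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# "^2.27.2" in Poetry resolves to this specifier set.
spec = SpecifierSet(">=2.27.2,<3.0.0")
assert Version("2.27.2") in spec
assert Version("2.28.0") in spec
assert Version("3.0.0") not in spec

# At runtime, the installed version can be checked the same way.
installed = Version(importlib.metadata.version("supabase"))
print(installed in spec)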
@@ -1,5 +1,3 @@
-# ============================ DEPENDENCY BUILDER ============================ #
-
FROM debian:13-slim AS builder

# Set environment variables
@@ -53,9 +51,7 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub

-# ============================== BACKEND SERVER ============================== #
-
-FROM debian:13-slim AS server
+FROM debian:13-slim AS server_dependencies

WORKDIR /app

@@ -66,21 +62,16 @@ ENV POETRY_HOME=/opt/poetry \
    DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH

-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
-# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool.
-# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
-RUN apt-get update && apt-get install -y --no-install-recommends \
+# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
+RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip \
    ffmpeg \
    imagemagick \
-    jq \
-    ripgrep \
-    tree \
-    bubblewrap \
    && rm -rf /var/lib/apt/lists/*

+# Copy only necessary files from builder
+COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
@@ -90,54 +81,30 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

-WORKDIR /app/autogpt_platform/backend
-
-# Copy only the .venv from builder (not the entire /app directory)
-# The .venv includes the generated Prisma client
-COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

-# Copy dependency files + autogpt_libs (path dependency)
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
-
-# Copy backend code + docs (for Copilot docs search)
-COPY autogpt_platform/backend ./
+RUN mkdir -p /app/autogpt_platform/autogpt_libs
+RUN mkdir -p /app/autogpt_platform/backend
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
+
+WORKDIR /app/autogpt_platform/backend
+
+FROM server_dependencies AS migrate
+
+# Migration stage only needs schema and migrations - much lighter than full backend
+COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
+COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
+COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
+
+FROM server_dependencies AS server
+
+COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
RUN poetry install --no-ansi --only-root

ENV PORT=8000

CMD ["poetry", "run", "rest"]

-# =============================== DB MIGRATOR =============================== #
-
-# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
-FROM debian:13-slim AS migrate
-
-WORKDIR /app/autogpt_platform/backend
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.13 \
-    python3-pip \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
-
-# Copy Node.js from builder (needed for Prisma CLI)
-COPY --from=builder /usr/bin/node /usr/bin/node
-COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-COPY --from=builder /usr/bin/npm /usr/bin/npm
-
-# Copy Prisma binaries
-COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-# Install prisma-client-py directly (much smaller than copying full venv)
-RUN pip3 install prisma>=0.15.0 --break-system-packages
-
-COPY autogpt_platform/backend/schema.prisma ./
-COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
-COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
-COPY autogpt_platform/backend/migrations ./migrations
@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
    session_ttl: int = Field(default=43200, description="Session TTL in seconds")

    # Streaming Configuration
-    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
-    max_retries: int = Field(
-        default=3,
-        description="Max retries for fallback path (SDK handles retries internally)",
-    )
+    max_context_messages: int = Field(
+        default=50, ge=1, le=200, description="Maximum context messages"
+    )
+
+    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
+    max_retries: int = Field(default=3, description="Maximum number of retries")
    max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
    max_agent_schedules: int = Field(
        default=30, description="Maximum number of agent schedules"
@@ -92,31 +93,6 @@ class ChatConfig(BaseSettings):
        description="Name of the prompt in Langfuse to fetch",
    )

-    # Claude Agent SDK Configuration
-    use_claude_agent_sdk: bool = Field(
-        default=True,
-        description="Use Claude Agent SDK for chat completions",
-    )
-    claude_agent_model: str | None = Field(
-        default=None,
-        description="Model for the Claude Agent SDK path. If None, derives from "
-        "the `model` field by stripping the OpenRouter provider prefix.",
-    )
-    claude_agent_max_buffer_size: int = Field(
-        default=10 * 1024 * 1024,  # 10MB (default SDK is 1MB)
-        description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
-        "Increase if tool outputs exceed the limit.",
-    )
-    claude_agent_max_subtasks: int = Field(
-        default=10,
-        description="Max number of sub-agent Tasks the SDK can spawn per session.",
-    )
-    claude_agent_use_resume: bool = Field(
-        default=True,
-        description="Use --resume for multi-turn conversations instead of "
-        "history compression. Falls back to compression when unavailable.",
-    )
-
    # Extended thinking configuration for Claude models
    thinking_enabled: bool = Field(
        default=True,
@@ -162,17 +138,6 @@ class ChatConfig(BaseSettings):
        v = os.getenv("CHAT_INTERNAL_API_KEY")
        return v

-    @field_validator("use_claude_agent_sdk", mode="before")
-    @classmethod
-    def get_use_claude_agent_sdk(cls, v):
-        """Get use_claude_agent_sdk from environment if not provided."""
-        # Check environment variable - default to True if not set
-        env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
-        if env_val:
-            return env_val in ("true", "1", "yes", "on")
-        # Default to True (SDK enabled by default)
-        return True if v is None else v
-
    # Prompt paths for different contexts
    PROMPT_PATHS: dict[str, str] = {
        "default": "prompts/chat_system.md",
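For orientation: ChatConfig is a pydantic-settings BaseSettings class, and the removed get_use_claude_agent_sdk validator is the usual "environment variable overrides the field default" pattern. A minimal, self-contained sketch of that pattern, with made-up field and env-var names rather than the real ones, looks like this:

# Illustrative only: env-override pattern for a boolean settings field.
import os

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    use_new_path: bool = Field(default=True, description="Example feature toggle")

    @field_validator("use_new_path", mode="before")
    @classmethod
    def _use_new_path_from_env(cls, v):
        # An explicit env var wins; otherwise fall back to the field default.
        env_val = os.getenv("EXAMPLE_USE_NEW_PATH", "").lower()
        if env_val:
            return env_val in ("true", "1", "yes", "on")
        return True if v is None else v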
@@ -334,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
-            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
-            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
+            f"Loading session {session_id} from cache: "
+            f"message_count={len(session.messages)}, "
+            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
@@ -377,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
        return None

    messages = prisma_session.Messages
-    logger.debug(
-        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
-        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
+    logger.info(
+        f"Loading session {session_id} from DB: "
+        f"has_messages={messages is not None}, "
+        f"message_count={len(messages) if messages else 0}, "
+        f"roles={[m.role for m in messages] if messages else []}"
    )

    return ChatSession.from_db(prisma_session, messages)
@@ -430,9 +433,10 @@ async def _save_session_to_db(
                "function_call": msg.function_call,
            }
        )
-    logger.debug(
-        f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
-        f"roles={[m['role'] for m in messages_data]}"
+    logger.info(
+        f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
+        f"roles={[m['role'] for m in messages_data]}, "
+        f"start_sequence={existing_message_count}"
    )
    await chat_db.add_chat_messages_batch(
        session_id=session.session_id,
@@ -472,7 +476,7 @@ async def get_chat_session(
        logger.warning(f"Unexpected cache error for session {session_id}: {e}")

    # Fall back to database
-    logger.debug(f"Session {session_id} not in cache, checking database")
+    logger.info(f"Session {session_id} not in cache, checking database")
    session = await _get_session_from_db(session_id)

    if session is None:
@@ -489,6 +493,7 @@ async def get_chat_session(
    # Cache the session from DB
    try:
        await _cache_session(session)
+        logger.info(f"Cached session {session_id} from database")
    except Exception as e:
        logger.warning(f"Failed to cache session {session_id}: {e}")

@@ -553,40 +558,6 @@ async def upsert_chat_session(
    return session


-async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
-    """Atomically append a message to a session and persist it.
-
-    Acquires the session lock, re-fetches the latest session state,
-    appends the message, and saves — preventing message loss when
-    concurrent requests modify the same session.
-    """
-    lock = await _get_session_lock(session_id)
-
-    async with lock:
-        session = await get_chat_session(session_id)
-        if session is None:
-            raise ValueError(f"Session {session_id} not found")
-
-        session.messages.append(message)
-        existing_message_count = await chat_db.get_chat_session_message_count(
-            session_id
-        )
-
-        try:
-            await _save_session_to_db(session, existing_message_count)
-        except Exception as e:
-            raise DatabaseError(
-                f"Failed to persist message to session {session_id}"
-            ) from e
-
-        try:
-            await _cache_session(session)
-        except Exception as e:
-            logger.warning(f"Cache write failed for session {session_id}: {e}")
-
-        return session
-
-
async def create_chat_session(user_id: str) -> ChatSession:
    """Create a new chat session and persist it.

@@ -693,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
            logger.warning(f"Session {session_id} not found for title update")
            return False

-        # Update title in cache if it exists (instead of invalidating).
-        # This prevents race conditions where cache invalidation causes
-        # the frontend to see stale DB data while streaming is still in progress.
+        # Invalidate cache so next fetch gets updated title
        try:
-            cached = await _get_session_from_cache(session_id)
-            if cached:
-                cached.title = title
-                await _cache_session(cached)
+            redis_key = _get_session_cache_key(session_id)
+            async_redis = await get_redis_async()
+            await async_redis.delete(redis_key)
        except Exception as e:
-            # Not critical - title will be correct on next full cache refresh
-            logger.warning(
-                f"Failed to update title in cache for session {session_id}: {e}"
-            )
+            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")

        return True
    except Exception as e:
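The removed append_and_save_message helper is a lock-guarded read-modify-write: re-fetch the latest session state inside a per-session lock, append, then persist. A minimal stand-alone sketch of that pattern (in-memory stand-ins instead of the real cache and database helpers) is:

# Illustrative sketch, not the project's API.
import asyncio
from collections import defaultdict

_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
_sessions: dict[str, list[str]] = defaultdict(list)


async def append_message(session_id: str, message: str) -> list[str]:
    # Re-fetch the latest state inside the lock so two concurrent requests
    # cannot overwrite each other's appended messages (last-writer-wins loss).
    async with _locks[session_id]:
        messages = _sessions[session_id]
        messages.append(message)
        return messages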
@@ -1,6 +1,5 @@
"""Chat API routes for chat session management and streaming via SSE."""

-import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -12,29 +11,19 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from backend.util.exceptions import NotFoundError
-from backend.util.feature_flag import Flag, is_feature_enabled

from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
-from .model import (
-    ChatMessage,
-    ChatSession,
-    append_and_save_message,
-    create_chat_session,
-    get_chat_session,
-    get_user_sessions,
-)
-from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
-from .sdk import service as sdk_service
+from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
+from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
    AgentDetailsResponse,
    AgentOutputResponse,
    AgentPreviewResponse,
    AgentSavedResponse,
    AgentsFoundResponse,
-    BlockDetailsResponse,
    BlockListResponse,
    BlockOutputResponse,
    ClarificationNeededResponse,
@@ -51,7 +40,6 @@ from .tools.models import (
    SetupRequirementsResponse,
    UnderstandingUpdatedResponse,
)
-from .tracking import track_user_message

config = ChatConfig()

@@ -243,10 +231,6 @@ async def get_session(
    active_task, last_message_id = await stream_registry.get_active_task_for_session(
        session_id, user_id
    )
-    logger.info(
-        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
-        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
-    )
    if active_task:
        # Filter out the in-progress assistant message from the session response.
        # The client will receive the complete assistant response through the SSE
@@ -316,9 +300,10 @@ async def stream_chat_post(
        f"user={user_id}, message_len={len(request.message)}",
        extra={"json_fields": log_meta},
    )

    session = await _validate_and_get_session(session_id, user_id)
    logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
+        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
        extra={
            "json_fields": {
                **log_meta,
@@ -327,25 +312,6 @@ async def stream_chat_post(
        },
    )

-    # Atomically append user message to session BEFORE creating task to avoid
-    # race condition where GET_SESSION sees task as "running" but message isn't
-    # saved yet. append_and_save_message re-fetches inside a lock to prevent
-    # message loss from concurrent requests.
-    if request.message:
-        message = ChatMessage(
-            role="user" if request.is_user_message else "assistant",
-            content=request.message,
-        )
-        if request.is_user_message:
-            track_user_message(
-                user_id=user_id,
-                session_id=session_id,
-                message_length=len(request.message),
-            )
-        logger.info(f"[STREAM] Saving user message to session {session_id}")
-        session = await append_and_save_message(session_id, message)
-        logger.info(f"[STREAM] User message saved for session {session_id}")
-
    # Create a task in the stream registry for reconnection support
    task_id = str(uuid_module.uuid4())
    operation_id = str(uuid_module.uuid4())
@@ -361,7 +327,7 @@ async def stream_chat_post(
        operation_id=operation_id,
    )
    logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
        extra={
            "json_fields": {
                **log_meta,
@@ -382,47 +348,15 @@ async def stream_chat_post(
        first_chunk_time, ttfc = None, None
        chunk_count = 0
        try:
-            # Emit a start event with task_id for reconnection
-            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
-            await stream_registry.publish_chunk(task_id, start_chunk)
-            logger.info(
-                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
-                        * 1000,
-                    }
-                },
-            )
-
-            # Choose service based on LaunchDarkly flag (falls back to config default)
-            use_sdk = await is_feature_enabled(
-                Flag.COPILOT_SDK,
-                user_id or "anonymous",
-                default=config.use_claude_agent_sdk,
-            )
-            stream_fn = (
-                sdk_service.stream_chat_completion_sdk
-                if use_sdk
-                else chat_service.stream_chat_completion
-            )
-            logger.info(
-                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
-                extra={"json_fields": log_meta},
-            )
-            # Pass message=None since we already added it to the session above
-            async for chunk in stream_fn(
+            async for chunk in chat_service.stream_chat_completion(
                session_id,
-                None,  # Message already in session
+                request.message,
                is_user_message=request.is_user_message,
                user_id=user_id,
-                session=session,  # Pass session with message already added
+                session=session,  # Pass pre-fetched session to avoid double-fetch
                context=request.context,
+                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
            ):
-                # Skip duplicate StreamStart — we already published one above
-                if isinstance(chunk, StreamStart):
-                    continue
                chunk_count += 1
                if first_chunk_time is None:
                    first_chunk_time = time_module.perf_counter()
@@ -443,7 +377,7 @@ async def stream_chat_post(
        gen_end_time = time_module.perf_counter()
        total_time = (gen_end_time - gen_start_time) * 1000
        logger.info(
-            f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
+            f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
            f"task={task_id}, session={session_id}, "
            f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
            extra={
@@ -470,17 +404,6 @@ async def stream_chat_post(
                }
            },
        )
-        # Publish a StreamError so the frontend can display an error message
-        try:
-            await stream_registry.publish_chunk(
-                task_id,
-                StreamError(
-                    errorText="An error occurred. Please try again.",
-                    code="stream_error",
-                ),
-            )
-        except Exception:
-            pass  # Best-effort; mark_task_completed will publish StreamFinish
        await stream_registry.mark_task_completed(task_id, "failed")

    # Start the AI generation in a background task
@@ -583,14 +506,8 @@ async def stream_chat_post(
                    "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
                },
            )
-            # Surface error to frontend so it doesn't appear stuck
-            yield StreamError(
-                errorText="An error occurred. Please try again.",
-                code="stream_error",
-            ).to_sse()
-            yield StreamFinish().to_sse()
        finally:
-            # Unsubscribe when client disconnects or stream ends
+            # Unsubscribe when client disconnects or stream ends to prevent resource leak
            if subscriber_queue is not None:
                try:
                    await stream_registry.unsubscribe_from_task(
@@ -834,6 +751,8 @@ async def stream_task(
    )

    async def event_generator() -> AsyncGenerator[str, None]:
+        import asyncio
+
        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
        try:
            while True:
@@ -1052,7 +971,6 @@ ToolResponseUnion = (
    | AgentSavedResponse
    | ClarificationNeededResponse
    | BlockListResponse
-    | BlockDetailsResponse
    | BlockOutputResponse
    | DocSearchResultsResponse
    | DocPageResponse
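For readers less familiar with the SSE plumbing touched above: stream_task's event_generator yields server-sent events and emits a heartbeat roughly every 15 seconds to keep idle connections alive. A rough, self-contained sketch of that shape (endpoint path and payloads are invented for illustration) is:

# Illustrative FastAPI SSE endpoint with periodic heartbeats.
import asyncio
from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/example/stream")
async def example_stream() -> StreamingResponse:
    async def event_generator() -> AsyncGenerator[str, None]:
        heartbeat_interval = 15.0  # mirrors the 15s heartbeat above
        for i in range(3):
            yield f"data: chunk {i}\n\n"
            try:
                await asyncio.sleep(heartbeat_interval)
            except asyncio.CancelledError:
                return  # client disconnected; stop producing events
            yield ": heartbeat\n\n"  # SSE comment line, ignored by clients
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")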
@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.

This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""

from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server

__all__ = [
    "stream_chat_completion_sdk",
    "create_copilot_mcp_server",
]
@@ -1,203 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
    MCP_TOOL_PREFIX,
    pop_pending_tool_output,
)

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None):
        self.message_id = message_id or str(uuid.uuid4())
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.task_id: str | None = None
        self.step_open = False

    def set_task_id(self, task_id: str) -> None:
        """Set the task ID for reconnection support."""
        self.task_id = task_id

    def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
        """Convert a single SDK message to Vercel AI SDK format."""
        responses: list[StreamBaseResponse] = []

        if isinstance(sdk_message, SystemMessage):
            if sdk_message.subtype == "init":
                responses.append(
                    StreamStart(messageId=self.message_id, taskId=self.task_id)
                )
                # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
                responses.append(StreamStartStep())
                self.step_open = True

        elif isinstance(sdk_message, AssistantMessage):
            # After tool results, the SDK sends a new AssistantMessage for the
            # next LLM turn. Open a new step if the previous one was closed.
            if not self.step_open:
                responses.append(StreamStartStep())
                self.step_open = True

            for block in sdk_message.content:
                if isinstance(block, TextBlock):
                    if block.text:
                        self._ensure_text_started(responses)
                        responses.append(
                            StreamTextDelta(id=self.text_block_id, delta=block.text)
                        )

                elif isinstance(block, ToolUseBlock):
                    self._end_text_if_open(responses)

                    # Strip MCP prefix so frontend sees "find_block"
                    # instead of "mcp__copilot__find_block".
                    tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)

                    responses.append(
                        StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
                    )
                    responses.append(
                        StreamToolInputAvailable(
                            toolCallId=block.id,
                            toolName=tool_name,
                            input=block.input,
                        )
                    )
                    self.current_tool_calls[block.id] = {"name": tool_name}

        elif isinstance(sdk_message, UserMessage):
            # UserMessage carries tool results back from tool execution.
            content = sdk_message.content
            blocks = content if isinstance(content, list) else []
            for block in blocks:
                if isinstance(block, ToolResultBlock) and block.tool_use_id:
                    tool_info = self.current_tool_calls.get(block.tool_use_id, {})
                    tool_name = tool_info.get("name", "unknown")

                    # Prefer the stashed full output over the SDK's
                    # (potentially truncated) ToolResultBlock content.
                    # The SDK truncates large results, writing them to disk,
                    # which breaks frontend widget parsing.
                    output = pop_pending_tool_output(tool_name) or (
                        _extract_tool_output(block.content)
                    )

                    responses.append(
                        StreamToolOutputAvailable(
                            toolCallId=block.tool_use_id,
                            toolName=tool_name,
                            output=output,
                            success=not (block.is_error or False),
                        )
                    )

            # Close the current step after tool results — the next
            # AssistantMessage will open a new step for the continuation.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

        elif isinstance(sdk_message, ResultMessage):
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

            if sdk_message.subtype == "success":
                responses.append(StreamFinish())
            elif sdk_message.subtype in ("error", "error_during_execution"):
                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
                responses.append(
                    StreamError(errorText=str(error_msg), code="sdk_error")
                )
                responses.append(StreamFinish())
            else:
                logger.warning(
                    f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
                )
                responses.append(StreamFinish())

        else:
            logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")

        return responses

    def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
        """Start (or restart) a text block if needed."""
        if not self.has_started_text or self.has_ended_text:
            if self.has_ended_text:
                self.text_block_id = str(uuid.uuid4())
                self.has_ended_text = False
            responses.append(StreamTextStart(id=self.text_block_id))
            self.has_started_text = True

    def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
        """End the current text block if one is open."""
        if self.has_started_text and not self.has_ended_text:
            responses.append(StreamTextEnd(id=self.text_block_id))
            self.has_ended_text = True


def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if item.get("type") == "text"]
        if parts:
            return "".join(parts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if content is None:
        return ""
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        return str(content)
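A hypothetical driver for the adapter removed above, to show how it was meant to be consumed (the function name and publish callback here are assumptions, not code from this repository): feed each Claude Agent SDK message through convert_message and forward the resulting frontend-format chunks.

# Illustrative only; SDKResponseAdapter refers to the removed module shown above.
from backend.api.features.chat.sdk.response_adapter import SDKResponseAdapter


async def relay(sdk_messages, publish) -> None:
    adapter = SDKResponseAdapter(message_id="example-message-id")
    adapter.set_task_id("example-task-id")
    async for sdk_message in sdk_messages:
        for chunk in adapter.convert_message(sdk_message):
            await publish(chunk)  # e.g. push onto an SSE subscriber queue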
@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""

from claude_agent_sdk import (
    AssistantMessage,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)

from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX


def _adapter() -> SDKResponseAdapter:
    a = SDKResponseAdapter(message_id="msg-1")
    a.set_task_id("task-1")
    return a


# -- SystemMessage -----------------------------------------------------------


def test_system_init_emits_start_and_step():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="init", data={}))
    assert len(results) == 2
    assert isinstance(results[0], StreamStart)
    assert results[0].messageId == "msg-1"
    assert results[0].taskId == "task-1"
    assert isinstance(results[1], StreamStartStep)


def test_system_non_init_emits_nothing():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="other", data={}))
    assert results == []


# -- AssistantMessage with TextBlock -----------------------------------------


def test_text_block_emits_step_start_and_delta():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "hello"


def test_empty_text_block_emits_only_step():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="")], model="test")
    results = adapter.convert_message(msg)
    # Empty text skipped, but step still opens
    assert len(results) == 1
    assert isinstance(results[0], StreamStartStep)


def test_multiple_text_deltas_reuse_block_id():
    adapter = _adapter()
    msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
    msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
    r1 = adapter.convert_message(msg1)
    r2 = adapter.convert_message(msg2)
    # First gets step+start+delta, second only delta (block & step already started)
    assert len(r1) == 3
    assert isinstance(r1[0], StreamStartStep)
    assert isinstance(r1[1], StreamTextStart)
    assert len(r2) == 1
    assert isinstance(r2[0], StreamTextDelta)
    assert r1[1].id == r2[0].id  # same block ID


# -- AssistantMessage with ToolUseBlock --------------------------------------


def test_tool_use_emits_input_start_and_available():
    """Tool names arrive with MCP prefix and should be stripped for the frontend."""
    adapter = _adapter()
    msg = AssistantMessage(
        content=[
            ToolUseBlock(
                id="tool-1",
                name=f"{MCP_TOOL_PREFIX}find_agent",
                input={"q": "x"},
            )
        ],
        model="test",
    )
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamToolInputStart)
    assert results[1].toolCallId == "tool-1"
    assert results[1].toolName == "find_agent"  # prefix stripped
    assert isinstance(results[2], StreamToolInputAvailable)
    assert results[2].toolName == "find_agent"  # prefix stripped
    assert results[2].input == {"q": "x"}


def test_text_then_tool_ends_text_block():
    adapter = _adapter()
    text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
        model="test",
    )
    adapter.convert_message(text_msg)  # opens step + text
    results = adapter.convert_message(tool_msg)
    # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamToolInputStart)


# -- UserMessage with ToolResultBlock ----------------------------------------


def test_tool_result_emits_output_and_finish_step():
    adapter = _adapter()
    # First register the tool call (opens step) — SDK sends prefixed name
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
        model="test",
    )
    adapter.convert_message(tool_msg)

    # Now send tool result
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
    )
    results = adapter.convert_message(result_msg)
    assert len(results) == 2
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].toolCallId == "t1"
    assert results[0].toolName == "find_agent"  # prefix stripped
    assert results[0].output == "found 3 agents"
    assert results[0].success is True
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_error():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[
                ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
            ],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].success is False
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_list_content():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[
            ToolResultBlock(
                tool_use_id="t1",
                content=[
                    {"type": "text", "text": "line1"},
                    {"type": "text", "text": "line2"},
                ],
            )
        ]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].output == "line1line2"
    assert isinstance(results[1], StreamFinishStep)


def test_string_user_message_ignored():
    """A plain string UserMessage (not tool results) produces no output."""
    adapter = _adapter()
    results = adapter.convert_message(UserMessage(content="hello"))
    assert results == []


# -- ResultMessage -----------------------------------------------------------


def test_result_success_emits_finish_step_and_finish():
    adapter = _adapter()
    # Start some text first (opens step)
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="done")], model="test")
    )
    msg = ResultMessage(
        subtype="success",
        duration_ms=100,
        duration_api_ms=50,
        is_error=False,
        num_turns=1,
        session_id="s1",
    )
    results = adapter.convert_message(msg)
    # TextEnd + FinishStep + StreamFinish
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamFinishStep)
    assert isinstance(results[2], StreamFinish)


def test_result_error_emits_error_and_finish():
    adapter = _adapter()
    msg = ResultMessage(
        subtype="error",
        duration_ms=100,
        duration_api_ms=50,
        is_error=True,
        num_turns=0,
        session_id="s1",
        result="API rate limited",
    )
    results = adapter.convert_message(msg)
    # No step was open, so no FinishStep — just Error + Finish
    assert len(results) == 2
    assert isinstance(results[0], StreamError)
    assert "API rate limited" in results[0].errorText
    assert isinstance(results[1], StreamFinish)


# -- Text after tools (new block ID) ----------------------------------------


def test_text_after_tool_gets_new_block_id():
    adapter = _adapter()
    # Text -> Tool -> ToolResult -> Text should get a new text block ID and step
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="before")], model="test")
    )
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    # Send tool result (closes step)
    adapter.convert_message(
        UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
    )
    results = adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="after")], model="test")
    )
    # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "after"


# -- Full conversation flow --------------------------------------------------


def test_full_conversation_flow():
    """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
    adapter = _adapter()
    all_responses: list[StreamBaseResponse] = []

    # 1. Init
    all_responses.extend(
        adapter.convert_message(SystemMessage(subtype="init", data={}))
    )
    # 2. Assistant text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
        )
    )
    # 3. Tool use
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(
                        id="t1",
                        name=f"{MCP_TOOL_PREFIX}find_agent",
                        input={"query": "email"},
                    )
                ],
                model="test",
            )
        )
    )
    # 4. Tool result
    all_responses.extend(
        adapter.convert_message(
            UserMessage(
                content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
            )
        )
    )
    # 5. More text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
        )
    )
    # 6. Result
    all_responses.extend(
        adapter.convert_message(
            ResultMessage(
                subtype="success",
                duration_ms=500,
                duration_api_ms=400,
                is_error=False,
                num_turns=2,
                session_id="s1",
            )
        )
    )

    types = [type(r).__name__ for r in all_responses]
    assert types == [
        "StreamStart",
        "StreamStartStep",  # step 1: text + tool call
        "StreamTextStart",
        "StreamTextDelta",  # "Let me search"
        "StreamTextEnd",  # closed before tool
        "StreamToolInputStart",
        "StreamToolInputAvailable",
        "StreamToolOutputAvailable",  # tool result
        "StreamFinishStep",  # step 1 closed after tool result
        "StreamStartStep",  # step 2: continuation text
        "StreamTextStart",  # new block after tool
        "StreamTextDelta",  # "I found 2"
        "StreamTextEnd",  # closed by result
        "StreamFinishStep",  # step 2 closed
        "StreamFinish",
    ]
@@ -1,335 +0,0 @@
|
|||||||
"""Security hooks for Claude Agent SDK integration.
|
|
||||||
|
|
||||||
This module provides security hooks that validate tool calls before execution,
|
|
||||||
ensuring multi-user isolation and preventing unauthorized operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tools that are blocked entirely (CLI/system access).
|
|
||||||
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
|
|
||||||
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
|
|
||||||
# which has kernel-level network isolation (unshare --net).
|
|
||||||
BLOCKED_TOOLS = {
|
|
||||||
"Bash",
|
|
||||||
"bash",
|
|
||||||
"shell",
|
|
||||||
"exec",
|
|
||||||
"terminal",
|
|
||||||
"command",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Tools allowed only when their path argument stays within the SDK workspace.
|
|
||||||
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
|
||||||
# files, then reads them back) and for workspace file operations.
|
|
||||||
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
|
||||||
|
|
||||||
# Dangerous patterns in tool inputs
|
|
||||||
DANGEROUS_PATTERNS = [
|
|
||||||
r"sudo",
|
|
||||||
r"rm\s+-rf",
|
|
||||||
r"dd\s+if=",
|
|
||||||
r"/etc/passwd",
|
|
||||||
r"/etc/shadow",
|
|
||||||
r"chmod\s+777",
|
|
||||||
r"curl\s+.*\|.*sh",
|
|
||||||
r"wget\s+.*\|.*sh",
|
|
||||||
r"eval\s*\(",
|
|
||||||
r"exec\s*\(",
|
|
||||||
r"__import__",
|
|
||||||
r"os\.system",
|
|
||||||
r"subprocess",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _deny(reason: str) -> dict[str, Any]:
|
|
||||||
"""Return a hook denial response."""
|
|
||||||
return {
|
|
||||||
"hookSpecificOutput": {
|
|
||||||
"hookEventName": "PreToolUse",
|
|
||||||
"permissionDecision": "deny",
|
|
||||||
"permissionDecisionReason": reason,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_workspace_path(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
|
||||||
|
|
||||||
Allowed directories:
|
|
||||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
|
||||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
|
||||||
"""
|
|
||||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
|
||||||
if not path:
|
|
||||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
|
||||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
|
||||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
|
||||||
if path.startswith("~"):
|
|
||||||
resolved = os.path.realpath(os.path.expanduser(path))
|
|
||||||
elif not os.path.isabs(path) and sdk_cwd:
|
|
||||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
|
||||||
else:
|
|
||||||
resolved = os.path.realpath(path)
|
|
||||||
|
|
||||||
# Allow access within the SDK working directory
|
|
||||||
if sdk_cwd:
|
|
||||||
norm_cwd = os.path.realpath(sdk_cwd)
|
|
||||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
|
||||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
|
||||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
|
||||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
|
||||||
)
|
|
||||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
|
||||||
return _deny(
|
|
||||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
|
||||||
f"directory.{workspace_hint} "
|
|
||||||
"This is enforced by the platform and cannot be bypassed."
|
|
||||||
)
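# Illustrative only (not part of the original module): how the resolution rules
# above play out for a hypothetical session workspace. Relative paths are joined
# to sdk_cwd, "~" expands to the home directory, and anything that resolves
# outside the workspace or the tool-results area is denied.
#
#     sdk_cwd = "/tmp/copilot-abc123"
#     _validate_workspace_path("Read", {"file_path": "notes.txt"}, sdk_cwd)
#     # -> {}  (resolves to /tmp/copilot-abc123/notes.txt, inside the workspace)
#     _validate_workspace_path("Read", {"file_path": "../../etc/passwd"}, sdk_cwd)
#     # -> denial dict (realpath escapes the workspace)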
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_tool_access(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that a tool call is allowed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
|
||||||
"""
|
|
||||||
# Block forbidden tools
|
|
||||||
if tool_name in BLOCKED_TOOLS:
|
|
||||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
|
||||||
return _deny(
|
|
||||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
|
||||||
"This is enforced by the platform and cannot be bypassed. "
|
|
||||||
"Use the CoPilot-specific MCP tools instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
|
||||||
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
|
||||||
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
|
||||||
|
|
||||||
# Check for dangerous patterns in tool input
|
|
||||||
# Use json.dumps for predictable format (str() produces Python repr)
|
|
||||||
input_str = json.dumps(tool_input) if tool_input else ""
|
|
||||||
|
|
||||||
for pattern in DANGEROUS_PATTERNS:
|
|
||||||
if re.search(pattern, input_str, re.IGNORECASE):
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
|
||||||
)
|
|
||||||
return _deny(
|
|
||||||
"[SECURITY] Input contains a blocked pattern. "
|
|
||||||
"This is enforced by the platform and cannot be bypassed."
|
|
||||||
)
|
|
||||||
|
|
||||||
return {}
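# Illustrative only: the pattern check above serializes tool_input with
# json.dumps and scans the resulting JSON text case-insensitively, so a match
# anywhere in the arguments is enough to deny the call. A hypothetical example:
#
#     _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
#     # -> denial dict ("sudo" and r"rm\s+-rf" both match the serialized input)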
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_user_isolation(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that tool calls respect user isolation."""
|
|
||||||
# For workspace file tools, ensure path doesn't escape
|
|
||||||
if "workspace" in tool_name.lower():
|
|
||||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
|
||||||
if path:
|
|
||||||
# Check for path traversal
|
|
||||||
if ".." in path or path.startswith("/"):
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"hookSpecificOutput": {
|
|
||||||
"hookEventName": "PreToolUse",
|
|
||||||
"permissionDecision": "deny",
|
|
||||||
"permissionDecisionReason": "Path traversal not allowed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def create_security_hooks(
|
|
||||||
user_id: str | None,
|
|
||||||
sdk_cwd: str | None = None,
|
|
||||||
max_subtasks: int = 3,
|
|
||||||
on_stop: Callable[[str, str], None] | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create the security hooks configuration for Claude Agent SDK.
|
|
||||||
|
|
||||||
Includes security validation and observability hooks:
|
|
||||||
- PreToolUse: Security validation before tool execution
|
|
||||||
- PostToolUse: Log successful tool executions
|
|
||||||
- PostToolUseFailure: Log and handle failed tool executions
|
|
||||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
|
||||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Current user ID for isolation validation
|
|
||||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
|
||||||
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
|
|
||||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
|
||||||
the SDK finishes processing — used to read the JSONL transcript
|
|
||||||
before the CLI process exits.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Hooks configuration dict for ClaudeAgentOptions
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from claude_agent_sdk import HookMatcher
|
|
||||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
|
||||||
|
|
||||||
# Per-session counter for Task sub-agent spawns
|
|
||||||
task_spawn_count = 0
|
|
||||||
|
|
||||||
async def pre_tool_use_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Combined pre-tool-use validation hook."""
|
|
||||||
nonlocal task_spawn_count
|
|
||||||
_ = context # unused but required by signature
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
|
||||||
|
|
||||||
# Rate-limit Task (sub-agent) spawns per session
|
|
||||||
if tool_name == "Task":
|
|
||||||
task_spawn_count += 1
|
|
||||||
if task_spawn_count > max_subtasks:
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
|
||||||
)
|
|
||||||
return cast(
|
|
||||||
SyncHookJSONOutput,
|
|
||||||
_deny(
|
|
||||||
f"Maximum {max_subtasks} sub-tasks per session. "
|
|
||||||
"Please continue in the main conversation."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Strip MCP prefix for consistent validation
|
|
||||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
|
||||||
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
|
||||||
|
|
||||||
# Only block non-CoPilot tools; our MCP-registered tools
|
|
||||||
# (including Read for oversized results) are already sandboxed.
|
|
||||||
if not is_copilot_tool:
|
|
||||||
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
|
||||||
if result:
|
|
||||||
return cast(SyncHookJSONOutput, result)
|
|
||||||
|
|
||||||
# Validate user isolation
|
|
||||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
|
||||||
if result:
|
|
||||||
return cast(SyncHookJSONOutput, result)
|
|
||||||
|
|
||||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def post_tool_use_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log successful tool executions for observability."""
|
|
||||||
_ = context
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def post_tool_failure_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log failed tool executions for debugging."""
|
|
||||||
_ = context
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
error = input_data.get("error", "Unknown error")
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
|
||||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
|
||||||
)
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def pre_compact_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log when SDK triggers context compaction.
|
|
||||||
|
|
||||||
The SDK automatically compacts conversation history when it grows too large.
|
|
||||||
This hook provides visibility into when compaction happens.
|
|
||||||
"""
|
|
||||||
_ = context, tool_use_id
|
|
||||||
trigger = input_data.get("trigger", "auto")
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
|
||||||
)
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
# --- Stop hook: capture transcript path for stateless resume ---
|
|
||||||
async def stop_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Capture transcript path when SDK finishes processing.
|
|
||||||
|
|
||||||
The Stop hook fires while the CLI process is still alive, giving us
|
|
||||||
a reliable window to read the JSONL transcript before SIGTERM.
|
|
||||||
"""
|
|
||||||
_ = context, tool_use_id
|
|
||||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
|
||||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
|
||||||
|
|
||||||
if transcript_path and on_stop:
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
|
||||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
|
||||||
)
|
|
||||||
on_stop(transcript_path, sdk_session_id)
|
|
||||||
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
hooks: dict[str, Any] = {
|
|
||||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
|
||||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
|
||||||
"PostToolUseFailure": [
|
|
||||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
|
||||||
],
|
|
||||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
|
||||||
}
|
|
||||||
|
|
||||||
if on_stop is not None:
|
|
||||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
|
||||||
|
|
||||||
return hooks
|
|
||||||
except ImportError:
|
|
||||||
# Fallback for when SDK isn't available - return empty hooks
|
|
||||||
logger.warning("claude-agent-sdk not available, security hooks disabled")
|
|
||||||
return {}
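# Illustrative only: a minimal sketch of how the returned hooks mapping is meant
# to be consumed (the real wiring lives in the SDK service layer). The on_stop
# callback receives the transcript path and the SDK session id.
#
#     def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
#         ...  # read the JSONL transcript before the CLI process exits
#
#     hooks = create_security_hooks(
#         user_id="user-1", sdk_cwd="/tmp/copilot-abc123", on_stop=_on_stop
#     )
#     options = ClaudeAgentOptions(hooks=hooks, cwd="/tmp/copilot-abc123", ...)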
@@ -1,165 +0,0 @@
"""Unit tests for SDK security hooks."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
|
||||||
|
|
||||||
SDK_CWD = "/tmp/copilot-abc123"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_denied(result: dict) -> bool:
|
|
||||||
hook = result.get("hookSpecificOutput", {})
|
|
||||||
return hook.get("permissionDecision") == "deny"
|
|
||||||
|
|
||||||
|
|
||||||
# -- Blocked tools -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_blocked_tools_denied():
|
|
||||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
|
||||||
result = _validate_tool_access(tool, {})
|
|
||||||
assert _is_denied(result), f"{tool} should be blocked"
|
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_tool_allowed():
|
|
||||||
result = _validate_tool_access("SomeCustomTool", {})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
# -- Workspace-scoped tools --------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_edit_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_glob_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_grep_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_outside_workspace_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_outside_workspace_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_traversal_attack_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read",
|
|
||||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
|
||||||
sdk_cwd=SDK_CWD,
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_path_allowed():
|
|
||||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
|
||||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_no_cwd_denies_absolute():
|
|
||||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
|
||||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Tool-results directory --------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_tool_results_allowed():
|
|
||||||
home = os.path.expanduser("~")
|
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
|
||||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_claude_projects_without_tool_results_denied():
|
|
||||||
home = os.path.expanduser("~")
|
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
|
||||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_bash_builtin_always_blocked():
|
|
||||||
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
|
||||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Dangerous patterns ------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_dangerous_pattern_blocked():
|
|
||||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_subprocess_pattern_blocked():
|
|
||||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- User isolation ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_path_traversal_blocked():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_absolute_path_blocked():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_normal_path_allowed():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_workspace_tool_passes_isolation():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"find_agent", {"query": "email"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert result == {}
@@ -1,751 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
|
||||||
|
|
||||||
from .. import stream_registry
|
|
||||||
from ..config import ChatConfig
|
|
||||||
from ..model import (
|
|
||||||
ChatMessage,
|
|
||||||
ChatSession,
|
|
||||||
get_chat_session,
|
|
||||||
update_session_title,
|
|
||||||
upsert_chat_session,
|
|
||||||
)
|
|
||||||
from ..response_model import (
|
|
||||||
StreamBaseResponse,
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamStart,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamToolInputAvailable,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
)
|
|
||||||
from ..service import (
|
|
||||||
_build_system_prompt,
|
|
||||||
_execute_long_running_tool_with_streaming,
|
|
||||||
_generate_session_title,
|
|
||||||
)
|
|
||||||
from ..tools.models import OperationPendingResponse, OperationStartedResponse
|
|
||||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
|
||||||
from ..tracking import track_user_message
|
|
||||||
from .response_adapter import SDKResponseAdapter
|
|
||||||
from .security_hooks import create_security_hooks
|
|
||||||
from .tool_adapter import (
|
|
||||||
COPILOT_TOOL_NAMES,
|
|
||||||
LongRunningCallback,
|
|
||||||
create_copilot_mcp_server,
|
|
||||||
set_execution_context,
|
|
||||||
)
|
|
||||||
from .transcript import (
|
|
||||||
download_transcript,
|
|
||||||
read_transcript_file,
|
|
||||||
upload_transcript,
|
|
||||||
validate_transcript,
|
|
||||||
write_transcript_to_tempfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
# Set to hold background tasks to prevent garbage collection
|
|
||||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CapturedTranscript:
|
|
||||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
|
||||||
|
|
||||||
path: str = ""
|
|
||||||
sdk_session_id: str = ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available(self) -> bool:
|
|
||||||
return bool(self.path)
|
|
||||||
|
|
||||||
|
|
||||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
|
||||||
|
|
||||||
# Appended to the system prompt to inform the agent about available tools.
|
|
||||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
|
||||||
# which has kernel-level network isolation (unshare --net).
|
|
||||||
_SDK_TOOL_SUPPLEMENT = """
|
|
||||||
|
|
||||||
## Tool notes
|
|
||||||
|
|
||||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
|
||||||
for shell commands — it runs in a network-isolated sandbox.
|
|
||||||
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
|
||||||
same working directory. Files created by one are readable by the other.
|
|
||||||
These files are **ephemeral** — they exist only for the current session.
|
|
||||||
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
|
|
||||||
for files that should persist across sessions (stored in cloud storage).
|
|
||||||
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
|
||||||
asynchronously. You will receive an immediate response; the actual result
|
|
||||||
is delivered to the user via a background stream.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
|
||||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
|
||||||
|
|
||||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
|
||||||
existing background infrastructure: stream_registry (Redis Streams),
|
|
||||||
database persistence, and SSE reconnection. This means results survive
|
|
||||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
|
||||||
widget with progress updates.
|
|
||||||
|
|
||||||
The returned callback matches the ``LongRunningCallback`` signature:
|
|
||||||
``(tool_name, args, session) -> MCP response dict``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def _callback(
|
|
||||||
tool_name: str, args: dict[str, Any], session: ChatSession
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
operation_id = str(uuid.uuid4())
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
# --- Build user-friendly messages (matches non-SDK service) ---
|
|
||||||
if tool_name == "create_agent":
|
|
||||||
desc = args.get("description", "")
|
|
||||||
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
|
|
||||||
pending_msg = (
|
|
||||||
f"Creating your agent: {desc_preview}"
|
|
||||||
if desc_preview
|
|
||||||
else "Creating agent... This may take a few minutes."
|
|
||||||
)
|
|
||||||
started_msg = (
|
|
||||||
"Agent creation started. You can close this tab - "
|
|
||||||
"check your library in a few minutes."
|
|
||||||
)
|
|
||||||
elif tool_name == "edit_agent":
|
|
||||||
changes = args.get("changes", "")
|
|
||||||
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
|
||||||
pending_msg = (
|
|
||||||
f"Editing agent: {changes_preview}"
|
|
||||||
if changes_preview
|
|
||||||
else "Editing agent... This may take a few minutes."
|
|
||||||
)
|
|
||||||
started_msg = (
|
|
||||||
"Agent edit started. You can close this tab - "
|
|
||||||
"check your library in a few minutes."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
|
||||||
started_msg = (
|
|
||||||
f"{tool_name} started. You can close this tab - "
|
|
||||||
"check back in a few minutes."
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Register task in Redis for SSE reconnection ---
|
|
||||||
await stream_registry.create_task(
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
operation_id=operation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Save OperationPendingResponse to chat history ---
|
|
||||||
pending_message = ChatMessage(
|
|
||||||
role="tool",
|
|
||||||
content=OperationPendingResponse(
|
|
||||||
message=pending_msg,
|
|
||||||
operation_id=operation_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
).model_dump_json(),
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
)
|
|
||||||
session.messages.append(pending_message)
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
|
||||||
bg_task = asyncio.create_task(
|
|
||||||
_execute_long_running_tool_with_streaming(
|
|
||||||
tool_name=tool_name,
|
|
||||||
parameters=args,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_background_tasks.add(bg_task)
|
|
||||||
bg_task.add_done_callback(_background_tasks.discard)
|
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Long-running tool {tool_name} delegated to background "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Return OperationStartedResponse as MCP tool result ---
|
|
||||||
# This flows through SDK → response adapter → frontend, triggering
|
|
||||||
# the loading widget with SSE reconnection support.
|
|
||||||
started_json = OperationStartedResponse(
|
|
||||||
message=started_msg,
|
|
||||||
operation_id=operation_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
task_id=task_id,
|
|
||||||
).model_dump_json()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": started_json}],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _callback
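# Illustrative only: the callback produced above is installed per-request via
# set_execution_context(); the MCP tool handler then invokes it for any tool
# flagged is_long_running, so the SDK gets an immediate OperationStartedResponse
# while the real work continues on the background infrastructure.
#
#     callback = _build_long_running_callback(user_id)
#     set_execution_context(user_id, session, long_running_callback=callback)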
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_sdk_model() -> str | None:
|
|
||||||
"""Resolve the model name for the Claude Agent SDK CLI.
|
|
||||||
|
|
||||||
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
|
||||||
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
|
||||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
|
||||||
"""
|
|
||||||
if config.claude_agent_model:
|
|
||||||
return config.claude_agent_model
|
|
||||||
model = config.model
|
|
||||||
if "/" in model:
|
|
||||||
return model.split("/", 1)[1]
|
|
||||||
return model
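# Illustrative only (hypothetical config values): with claude_agent_model unset,
#     config.model = "anthropic/claude-opus-4.6"  ->  "claude-opus-4.6"
#     config.model = "claude-opus-4.6"            ->  "claude-opus-4.6" (no prefix to strip)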
|
|
||||||
|
|
||||||
|
|
||||||
def _build_sdk_env() -> dict[str, str]:
|
|
||||||
"""Build env vars for the SDK CLI process.
|
|
||||||
|
|
||||||
Routes API calls through OpenRouter (or a custom base_url) using
|
|
||||||
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
|
|
||||||
This gives per-call token and cost tracking on the OpenRouter dashboard.
|
|
||||||
|
|
||||||
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
|
|
||||||
token are both present — otherwise returns an empty dict so the SDK
|
|
||||||
falls back to its default credentials.
|
|
||||||
"""
|
|
||||||
env: dict[str, str] = {}
|
|
||||||
if config.api_key and config.base_url:
|
|
||||||
# Strip /v1 suffix — SDK expects the base URL without a version path
|
|
||||||
base = config.base_url.rstrip("/")
|
|
||||||
if base.endswith("/v1"):
|
|
||||||
base = base[:-3]
|
|
||||||
if not base or not base.startswith("http"):
|
|
||||||
# Invalid base_url — don't override SDK defaults
|
|
||||||
return env
|
|
||||||
env["ANTHROPIC_BASE_URL"] = base
|
|
||||||
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
|
|
||||||
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
|
||||||
env["ANTHROPIC_API_KEY"] = ""
|
|
||||||
return env
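# Illustrative only (hypothetical values): with base_url "https://openrouter.ai/api/v1"
# and an API key configured, the CLI environment becomes:
#     ANTHROPIC_BASE_URL   = "https://openrouter.ai/api"   (trailing /v1 stripped)
#     ANTHROPIC_AUTH_TOKEN = <config.api_key>
#     ANTHROPIC_API_KEY    = ""                             (blank so AUTH_TOKEN is used)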
|
|
||||||
|
|
||||||
|
|
||||||
def _make_sdk_cwd(session_id: str) -> str:
|
|
||||||
"""Create a safe, session-specific working directory path.
|
|
||||||
|
|
||||||
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
|
|
||||||
(single source of truth for path sanitization) and adds a defence-in-depth
|
|
||||||
assertion.
|
|
||||||
"""
|
|
||||||
cwd = make_session_path(session_id)
|
|
||||||
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
|
|
||||||
cwd = os.path.normpath(cwd)
|
|
||||||
if not cwd.startswith(_SDK_CWD_PREFIX):
|
|
||||||
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
|
|
||||||
return cwd
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
|
||||||
"""Remove SDK tool-result files for a specific session working directory.
|
|
||||||
|
|
||||||
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
|
||||||
We clean only the specific cwd's results to avoid race conditions between
|
|
||||||
concurrent sessions.
|
|
||||||
|
|
||||||
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Validate cwd is under the expected prefix
|
|
||||||
normalized = os.path.normpath(cwd)
|
|
||||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
|
||||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# SDK encodes the cwd path by replacing '/' with '-'
|
|
||||||
encoded_cwd = normalized.replace("/", "-")
|
|
||||||
|
|
||||||
# Construct the project directory path (known-safe home expansion)
|
|
||||||
claude_projects = os.path.expanduser("~/.claude/projects")
|
|
||||||
project_dir = os.path.join(claude_projects, encoded_cwd)
|
|
||||||
|
|
||||||
# Validate the encoded project_dir is still under ~/.claude/projects
|
|
||||||
project_dir = os.path.normpath(project_dir)
|
|
||||||
if not project_dir.startswith(claude_projects):
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
results_dir = os.path.join(project_dir, "tool-results")
|
|
||||||
if os.path.isdir(results_dir):
|
|
||||||
for filename in os.listdir(results_dir):
|
|
||||||
file_path = os.path.join(results_dir, filename)
|
|
||||||
try:
|
|
||||||
if os.path.isfile(file_path):
|
|
||||||
os.remove(file_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Also clean up the temp cwd directory itself
|
|
||||||
try:
|
|
||||||
shutil.rmtree(normalized, ignore_errors=True)
|
|
||||||
except OSError:
|
|
||||||
pass
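# Illustrative only: the '/'-to-'-' encoding above means a session cwd of
# "/tmp/copilot-abc123" maps to the project directory
# "~/.claude/projects/-tmp-copilot-abc123", whose tool-results/ contents are
# removed before the session cwd itself is deleted.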
|
|
||||||
|
|
||||||
|
|
||||||
async def _compress_conversation_history(
|
|
||||||
session: ChatSession,
|
|
||||||
) -> list[ChatMessage]:
|
|
||||||
"""Compress prior conversation messages if they exceed the token threshold.
|
|
||||||
|
|
||||||
Uses the shared compress_context() from prompt.py which supports:
|
|
||||||
- LLM summarization of old messages (keeps recent ones intact)
|
|
||||||
- Progressive content truncation as fallback
|
|
||||||
- Middle-out deletion as last resort
|
|
||||||
|
|
||||||
Returns the compressed prior messages (everything except the current message).
|
|
||||||
"""
|
|
||||||
prior = session.messages[:-1]
|
|
||||||
if len(prior) < 2:
|
|
||||||
return prior
|
|
||||||
|
|
||||||
from backend.util.prompt import compress_context
|
|
||||||
|
|
||||||
# Convert ChatMessages to dicts for compress_context
|
|
||||||
messages_dict = []
|
|
||||||
for msg in prior:
|
|
||||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
|
||||||
if msg.content:
|
|
||||||
msg_dict["content"] = msg.content
|
|
||||||
if msg.tool_calls:
|
|
||||||
msg_dict["tool_calls"] = msg.tool_calls
|
|
||||||
if msg.tool_call_id:
|
|
||||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
|
||||||
messages_dict.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import openai
|
|
||||||
|
|
||||||
async with openai.AsyncOpenAI(
|
|
||||||
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
|
||||||
) as client:
|
|
||||||
result = await compress_context(
|
|
||||||
messages=messages_dict,
|
|
||||||
model=config.model,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
|
||||||
# Fall back to truncation-only (no LLM summarization)
|
|
||||||
result = await compress_context(
|
|
||||||
messages=messages_dict,
|
|
||||||
model=config.model,
|
|
||||||
client=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.was_compacted:
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Context compacted: {result.original_token_count} -> "
|
|
||||||
f"{result.token_count} tokens "
|
|
||||||
f"({result.messages_summarized} summarized, "
|
|
||||||
f"{result.messages_dropped} dropped)"
|
|
||||||
)
|
|
||||||
# Convert compressed dicts back to ChatMessages
|
|
||||||
return [
|
|
||||||
ChatMessage(
|
|
||||||
role=m["role"],
|
|
||||||
content=m.get("content"),
|
|
||||||
tool_calls=m.get("tool_calls"),
|
|
||||||
tool_call_id=m.get("tool_call_id"),
|
|
||||||
)
|
|
||||||
for m in result.messages
|
|
||||||
]
|
|
||||||
|
|
||||||
return prior
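# Illustrative only: in the no-resume fallback (see stream_chat_completion_sdk
# below) the compressed history is rendered into a <conversation_history>
# prefix and prepended to the new user message:
#
#     compressed = await _compress_conversation_history(session)
#     history_context = _format_conversation_context(compressed)
#     if history_context:
#         query_message = f"{history_context}\n\nNow, the user says:\n{current_message}"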
|
|
||||||
|
|
||||||
|
|
||||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
|
||||||
"""Format conversation messages into a context prefix for the user message.
|
|
||||||
|
|
||||||
Returns a string like:
|
|
||||||
<conversation_history>
|
|
||||||
User: hello
|
|
||||||
You responded: Hi! How can I help?
|
|
||||||
</conversation_history>
|
|
||||||
|
|
||||||
Returns None if there are no messages to format.
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
lines: list[str] = []
|
|
||||||
for msg in messages:
|
|
||||||
if not msg.content:
|
|
||||||
continue
|
|
||||||
if msg.role == "user":
|
|
||||||
lines.append(f"User: {msg.content}")
|
|
||||||
elif msg.role == "assistant":
|
|
||||||
lines.append(f"You responded: {msg.content}")
|
|
||||||
# Skip tool messages — they're internal details
|
|
||||||
|
|
||||||
if not lines:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat_completion_sdk(
|
|
||||||
session_id: str,
|
|
||||||
message: str | None = None,
|
|
||||||
tool_call_response: str | None = None, # noqa: ARG001
|
|
||||||
is_user_message: bool = True,
|
|
||||||
user_id: str | None = None,
|
|
||||||
retry_count: int = 0, # noqa: ARG001
|
|
||||||
session: ChatSession | None = None,
|
|
||||||
context: dict[str, str] | None = None, # noqa: ARG001
|
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
|
||||||
"""Stream chat completion using Claude Agent SDK.
|
|
||||||
|
|
||||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
|
||||||
|
|
||||||
if not session:
|
|
||||||
raise NotFoundError(
|
|
||||||
f"Session {session_id} not found. Please create a new session first."
|
|
||||||
)
|
|
||||||
|
|
||||||
if message:
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role="user" if is_user_message else "assistant", content=message
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if is_user_message:
|
|
||||||
track_user_message(
|
|
||||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# Generate title for new sessions (first user message)
|
|
||||||
if is_user_message and not session.title:
|
|
||||||
user_messages = [m for m in session.messages if m.role == "user"]
|
|
||||||
if len(user_messages) == 1:
|
|
||||||
first_message = user_messages[0].content or message or ""
|
|
||||||
if first_message:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
_update_title_async(session_id, first_message, user_id)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
|
|
||||||
# Build system prompt (reuses non-SDK path with Langfuse support)
|
|
||||||
has_history = len(session.messages) > 1
|
|
||||||
system_prompt, _ = await _build_system_prompt(
|
|
||||||
user_id, has_conversation_history=has_history
|
|
||||||
)
|
|
||||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
|
||||||
message_id = str(uuid.uuid4())
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
|
||||||
|
|
||||||
stream_completed = False
|
|
||||||
# Initialise sdk_cwd before the try so the finally can reference it
|
|
||||||
# even if _make_sdk_cwd raises (in that case it stays as "").
|
|
||||||
sdk_cwd = ""
|
|
||||||
use_resume = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
|
||||||
# between concurrent sessions.
|
|
||||||
sdk_cwd = _make_sdk_cwd(session_id)
|
|
||||||
os.makedirs(sdk_cwd, exist_ok=True)
|
|
||||||
|
|
||||||
set_execution_context(
|
|
||||||
user_id,
|
|
||||||
session,
|
|
||||||
long_running_callback=_build_long_running_callback(user_id),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
|
||||||
|
|
||||||
# Fail fast when no API credentials are available at all
|
|
||||||
sdk_env = _build_sdk_env()
|
|
||||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
|
||||||
raise RuntimeError(
|
|
||||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
|
||||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
|
||||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp_server = create_copilot_mcp_server()
|
|
||||||
|
|
||||||
sdk_model = _resolve_sdk_model()
|
|
||||||
|
|
||||||
# --- Transcript capture via Stop hook ---
|
|
||||||
captured_transcript = CapturedTranscript()
|
|
||||||
|
|
||||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
|
||||||
captured_transcript.path = transcript_path
|
|
||||||
captured_transcript.sdk_session_id = sdk_session_id
|
|
||||||
|
|
||||||
security_hooks = create_security_hooks(
|
|
||||||
user_id,
|
|
||||||
sdk_cwd=sdk_cwd,
|
|
||||||
max_subtasks=config.claude_agent_max_subtasks,
|
|
||||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Resume strategy: download transcript from bucket ---
|
|
||||||
resume_file: str | None = None
|
|
||||||
use_resume = False
|
|
||||||
|
|
||||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
|
||||||
transcript_content = await download_transcript(user_id, session_id)
|
|
||||||
if transcript_content and validate_transcript(transcript_content):
|
|
||||||
resume_file = write_transcript_to_tempfile(
|
|
||||||
transcript_content, session_id, sdk_cwd
|
|
||||||
)
|
|
||||||
if resume_file:
|
|
||||||
use_resume = True
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Using --resume with transcript "
|
|
||||||
f"({len(transcript_content)} bytes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
sdk_options_kwargs: dict[str, Any] = {
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
"mcp_servers": {"copilot": mcp_server},
|
|
||||||
"allowed_tools": COPILOT_TOOL_NAMES,
|
|
||||||
"disallowed_tools": ["Bash"],
|
|
||||||
"hooks": security_hooks,
|
|
||||||
"cwd": sdk_cwd,
|
|
||||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
|
||||||
}
|
|
||||||
if sdk_env:
|
|
||||||
sdk_options_kwargs["model"] = sdk_model
|
|
||||||
sdk_options_kwargs["env"] = sdk_env
|
|
||||||
if use_resume and resume_file:
|
|
||||||
sdk_options_kwargs["resume"] = resume_file
|
|
||||||
|
|
||||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
adapter = SDKResponseAdapter(message_id=message_id)
|
|
||||||
adapter.set_task_id(task_id)
|
|
||||||
|
|
||||||
async with ClaudeSDKClient(options=options) as client:
|
|
||||||
current_message = message or ""
|
|
||||||
if not current_message and session.messages:
|
|
||||||
last_user = [m for m in session.messages if m.role == "user"]
|
|
||||||
if last_user:
|
|
||||||
current_message = last_user[-1].content or ""
|
|
||||||
|
|
||||||
if not current_message.strip():
|
|
||||||
yield StreamError(
|
|
||||||
errorText="Message cannot be empty.",
|
|
||||||
code="empty_prompt",
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build query: with --resume the CLI already has full
|
|
||||||
# context, so we only send the new message. Without
|
|
||||||
# resume, compress history into a context prefix.
|
|
||||||
query_message = current_message
|
|
||||||
if not use_resume and len(session.messages) > 1:
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Using compression fallback for session "
|
|
||||||
f"{session_id} ({len(session.messages)} messages) — "
|
|
||||||
f"no transcript available for --resume"
|
|
||||||
)
|
|
||||||
compressed = await _compress_conversation_history(session)
|
|
||||||
history_context = _format_conversation_context(compressed)
|
|
||||||
if history_context:
|
|
||||||
query_message = (
|
|
||||||
f"{history_context}\n\n"
|
|
||||||
f"Now, the user says:\n{current_message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
|
|
||||||
)
|
|
||||||
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
|
|
||||||
await client.query(query_message, session_id=session_id)
|
|
||||||
|
|
||||||
assistant_response = ChatMessage(role="assistant", content="")
|
|
||||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
|
||||||
has_appended_assistant = False
|
|
||||||
has_tool_results = False
|
|
||||||
|
|
||||||
async for sdk_msg in client.receive_messages():
|
|
||||||
logger.debug(
|
|
||||||
f"[SDK] Received: {type(sdk_msg).__name__} "
|
|
||||||
f"{getattr(sdk_msg, 'subtype', '')}"
|
|
||||||
)
|
|
||||||
for response in adapter.convert_message(sdk_msg):
|
|
||||||
if isinstance(response, StreamStart):
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield response
|
|
||||||
|
|
||||||
if isinstance(response, StreamTextDelta):
|
|
||||||
delta = response.delta or ""
|
|
||||||
# After tool results, start a new assistant
|
|
||||||
# message for the post-tool text.
|
|
||||||
if has_tool_results and has_appended_assistant:
|
|
||||||
assistant_response = ChatMessage(
|
|
||||||
role="assistant", content=delta
|
|
||||||
)
|
|
||||||
accumulated_tool_calls = []
|
|
||||||
has_appended_assistant = False
|
|
||||||
has_tool_results = False
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
else:
|
|
||||||
assistant_response.content = (
|
|
||||||
assistant_response.content or ""
|
|
||||||
) + delta
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolInputAvailable):
|
|
||||||
accumulated_tool_calls.append(
|
|
||||||
{
|
|
||||||
"id": response.toolCallId,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": response.toolName,
|
|
||||||
"arguments": json.dumps(response.input or {}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolOutputAvailable):
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role="tool",
|
|
||||||
content=(
|
|
||||||
response.output
|
|
||||||
if isinstance(response.output, str)
|
|
||||||
else str(response.output)
|
|
||||||
),
|
|
||||||
tool_call_id=response.toolCallId,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
has_tool_results = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamFinish):
|
|
||||||
stream_completed = True
|
|
||||||
|
|
||||||
if stream_completed:
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
|
||||||
assistant_response.content or assistant_response.tool_calls
|
|
||||||
) and not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
|
|
||||||
# --- Capture transcript while CLI is still alive ---
|
|
||||||
# Must happen INSIDE async with: close() sends SIGTERM
|
|
||||||
# which kills the CLI before it can flush the JSONL.
|
|
||||||
if (
|
|
||||||
config.claude_agent_use_resume
|
|
||||||
and user_id
|
|
||||||
and captured_transcript.available
|
|
||||||
):
|
|
||||||
# Give CLI time to flush JSONL writes before we read
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
|
||||||
if raw_transcript:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
_upload_transcript_bg(user_id, session_id, raw_transcript)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
else:
|
|
||||||
logger.debug("[SDK] Stop hook fired but transcript not usable")
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"claude-agent-sdk is not installed. "
|
|
||||||
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
|
|
||||||
"to use the OpenAI-compatible fallback."
|
|
||||||
)
|
|
||||||
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
logger.debug(
|
|
||||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
|
||||||
)
|
|
||||||
if not stream_completed:
|
|
||||||
yield StreamFinish()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
|
||||||
try:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
except Exception as save_err:
|
|
||||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
|
||||||
yield StreamError(
|
|
||||||
errorText="An error occurred. Please try again.",
|
|
||||||
code="sdk_error",
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
finally:
|
|
||||||
if sdk_cwd:
|
|
||||||
_cleanup_sdk_tool_results(sdk_cwd)
|
|
||||||
|
|
||||||
|
|
||||||
async def _upload_transcript_bg(
|
|
||||||
user_id: str, session_id: str, raw_content: str
|
|
||||||
) -> None:
|
|
||||||
"""Background task to strip progress entries and upload transcript."""
|
|
||||||
try:
|
|
||||||
await upload_transcript(user_id, session_id, raw_content)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_title_async(
|
|
||||||
session_id: str, message: str, user_id: str | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Background task to update session title."""
|
|
||||||
try:
|
|
||||||
title = await _generate_session_title(
|
|
||||||
message, user_id=user_id, session_id=session_id
|
|
||||||
)
|
|
||||||
if title:
|
|
||||||
await update_session_title(session_id, title)
|
|
||||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SDK] Failed to update session title: {e}")
@@ -1,325 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
|
||||||
|
|
||||||
This module provides the adapter layer that converts existing BaseTool implementations
|
|
||||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
|
||||||
|
|
||||||
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
|
||||||
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
|
||||||
via a callback provided by the service layer. This avoids wasteful SDK polling
|
|
||||||
and makes results survive page refreshes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
|
||||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
|
||||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
|
||||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
|
||||||
|
|
||||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
|
||||||
MCP_SERVER_NAME = "copilot"
|
|
||||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
|
||||||
|
|
||||||
# Context variables to pass user/session info to tool execution
|
|
||||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
|
||||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
|
||||||
"current_session", default=None
|
|
||||||
)
|
|
||||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
|
||||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
|
||||||
# response adapter when it builds StreamToolOutputAvailable.
|
|
||||||
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
|
|
||||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
|
||||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
|
||||||
LongRunningCallback = Callable[
|
|
||||||
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
|
||||||
]
|
|
||||||
|
|
||||||
# ContextVar so the service layer can inject the callback per-request.
|
|
||||||
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
|
||||||
"long_running_callback", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def set_execution_context(
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
long_running_callback: LongRunningCallback | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Set the execution context for tool calls.
|
|
||||||
|
|
||||||
This must be called before streaming begins to ensure tools have access
|
|
||||||
to user_id and session information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Current user's ID.
|
|
||||||
session: Current chat session.
|
|
||||||
long_running_callback: Optional callback to delegate long-running tools
|
|
||||||
to the non-SDK background infrastructure (stream_registry + Redis).
|
|
||||||
"""
|
|
||||||
_current_user_id.set(user_id)
|
|
||||||
_current_session.set(session)
|
|
||||||
_pending_tool_outputs.set({})
|
|
||||||
_long_running_callback.set(long_running_callback)
|
|
||||||
|
|
||||||
|
|
||||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
|
||||||
"""Get the current execution context."""
|
|
||||||
return (
|
|
||||||
_current_user_id.get(),
|
|
||||||
_current_session.get(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
|
||||||
"""Pop and return the stashed full output for *tool_name*.
|
|
||||||
|
|
||||||
The SDK CLI may truncate large tool results (writing them to disk and
|
|
||||||
replacing the content with a file reference). This stash keeps the
|
|
||||||
original MCP output so the response adapter can forward it to the
|
|
||||||
frontend for proper widget rendering.
|
|
||||||
|
|
||||||
Returns ``None`` if nothing was stashed for *tool_name*.
|
|
||||||
"""
|
|
||||||
pending = _pending_tool_outputs.get(None)
|
|
||||||
if pending is None:
|
|
||||||
return None
|
|
||||||
return pending.pop(tool_name, None)
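# Illustrative only: a hypothetical stash/consume flow. _execute_tool_sync
# stashes the full output text before returning it to the SDK; the response
# adapter later pops it so the frontend receives the untruncated result:
#
#     pending = _pending_tool_outputs.get(None)          # stash side
#     pending["find_agent"] = full_output_text           # hypothetical value
#     ...
#     text = pop_pending_tool_output("find_agent")       # adapter side
#     if text is None:
#         text = truncated_output_from_sdk               # hypothetical fallback name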
|
|
||||||
|
|
||||||
|
|
||||||
async def _execute_tool_sync(
|
|
||||||
base_tool: BaseTool,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
args: dict[str, Any],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
|
||||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
|
||||||
result = await base_tool.execute(
|
|
||||||
user_id=user_id,
|
|
||||||
session=session,
|
|
||||||
tool_call_id=effective_id,
|
|
||||||
**args,
|
|
||||||
)
|
|
||||||
|
|
||||||
text = (
|
|
||||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stash the full output before the SDK potentially truncates it.
|
|
||||||
pending = _pending_tool_outputs.get(None)
|
|
||||||
if pending is not None:
|
|
||||||
pending[base_tool.name] = text
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": text}],
|
|
||||||
"isError": not result.success,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _mcp_error(message: str) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
|
|
||||||
],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_tool_handler(base_tool: BaseTool):
|
|
||||||
"""Create an async handler function for a BaseTool.
|
|
||||||
|
|
||||||
This wraps the existing BaseTool._execute method to be compatible
|
|
||||||
with the Claude Agent SDK MCP tool format.
|
|
||||||
|
|
||||||
Long-running tools (``is_long_running=True``) are delegated to the
|
|
||||||
non-SDK background infrastructure via a callback set in the execution
|
|
||||||
    context. The callback persists the operation in Redis (stream_registry)
    so results survive page refreshes and pod restarts.
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session = get_execution_context()

        if session is None:
            return _mcp_error("No session context available")

        # --- Long-running: delegate to non-SDK background infrastructure ---
        if base_tool.is_long_running:
            callback = _long_running_callback.get(None)
            if callback:
                try:
                    return await callback(base_tool.name, args, session)
                except Exception as e:
                    logger.error(
                        f"Long-running callback failed for {base_tool.name}: {e}",
                        exc_info=True,
                    )
                    return _mcp_error(f"Failed to start {base_tool.name}: {e}")
            # No callback — fall through to synchronous execution
            logger.warning(
                f"[SDK] No long-running callback for {base_tool.name}, "
                f"executing synchronously (may block)"
            )

        # --- Normal (fast) tool: execute synchronously ---
        try:
            return await _execute_tool_sync(base_tool, user_id, session, args)
        except Exception as e:
            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
            return _mcp_error(f"Failed to execute {base_tool.name}: {e}")

    return tool_handler


def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
    """Build a JSON Schema input schema for a tool."""
    return {
        "type": "object",
        "properties": base_tool.parameters.get("properties", {}),
        "required": base_tool.parameters.get("required", []),
    }


async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
    """Read a file with optional offset/limit. Restricted to SDK working directory.

    After reading, the file is deleted to prevent accumulation in long-running pods.
    """
    file_path = args.get("file_path", "")
    offset = args.get("offset", 0)
    limit = args.get("limit", 2000)

    # Security: only allow reads under ~/.claude/projects/**/tool-results/
    real_path = os.path.realpath(file_path)
    if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
        return {
            "content": [{"type": "text", "text": f"Access denied: {file_path}"}],
            "isError": True,
        }

    try:
        with open(real_path) as f:
            lines = f.readlines()
        selected = lines[offset : offset + limit]
        content = "".join(selected)
        # Clean up to prevent accumulation in long-running pods
        try:
            os.remove(real_path)
        except OSError:
            pass
        return {"content": [{"type": "text", "text": content}], "isError": False}
    except FileNotFoundError:
        return {
            "content": [{"type": "text", "text": f"File not found: {file_path}"}],
            "isError": True,
        }
    except Exception as e:
        return {
            "content": [{"type": "text", "text": f"Error reading file: {e}"}],
            "isError": True,
        }


_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
    "Read a file from the local filesystem. "
    "Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
    "type": "object",
    "properties": {
        "file_path": {
            "type": "string",
            "description": "The absolute path to the file to read",
        },
        "offset": {
            "type": "integer",
            "description": "Line number to start reading from (0-indexed). Default: 0",
        },
        "limit": {
            "type": "integer",
            "description": "Number of lines to read. Default: 2000",
        },
    },
    "required": ["file_path"],
}


# Create the MCP server configuration
def create_copilot_mcp_server():
    """Create an in-process MCP server configuration for CoPilot tools.

    This can be passed to ClaudeAgentOptions.mcp_servers.

    Note: The actual SDK MCP server creation depends on the claude-agent-sdk
    package being available. This function returns the configuration that
    can be used with the SDK.
    """
    try:
        from claude_agent_sdk import create_sdk_mcp_server, tool

        # Create decorated tool functions
        sdk_tools = []

        for tool_name, base_tool in TOOL_REGISTRY.items():
            handler = create_tool_handler(base_tool)
            decorated = tool(
                tool_name,
                base_tool.description,
                _build_input_schema(base_tool),
            )(handler)
            sdk_tools.append(decorated)

        # Add the Read tool so the SDK can read back oversized tool results
        read_tool = tool(
            _READ_TOOL_NAME,
            _READ_TOOL_DESCRIPTION,
            _READ_TOOL_SCHEMA,
        )(_read_file_handler)
        sdk_tools.append(read_tool)

        server = create_sdk_mcp_server(
            name=MCP_SERVER_NAME,
            version="1.0.0",
            tools=sdk_tools,
        )

        return server

    except ImportError:
        # Let ImportError propagate so service.py handles the fallback
        raise


# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]

# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
    *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
    f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
    *_SDK_BUILTIN_TOOLS,
]
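For orientation on how this module is consumed: create_copilot_mcp_server() builds the in-process MCP server, and COPILOT_TOOL_NAMES is the matching allow-list handed to the Claude Agent SDK. A minimal wiring sketch, assuming only the ClaudeAgentOptions fields referenced above (mcp_servers, allowed_tools); the real call site lives in service.py and may differ:

    # Illustrative sketch only; field names beyond mcp_servers are assumptions.
    from claude_agent_sdk import ClaudeAgentOptions

    def build_copilot_agent_options() -> ClaudeAgentOptions:
        server = create_copilot_mcp_server()  # raises ImportError when the SDK is missing
        return ClaudeAgentOptions(
            mcp_servers={MCP_SERVER_NAME: server},  # in-process server defined above
            allowed_tools=COPILOT_TOOL_NAMES,       # MCP tools + Read + SDK built-ins
        )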
@@ -1,355 +0,0 @@
|
|||||||
"""JSONL transcript management for stateless multi-turn resume.
|
|
||||||
|
|
||||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
|
||||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
|
||||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
|
||||||
next turn we download the transcript, write it to a temp file, and pass
|
|
||||||
``--resume`` so the CLI can reconstruct the full conversation.
|
|
||||||
|
|
||||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
|
||||||
filesystem for self-hosted) — no DB column needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
|
||||||
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
|
|
||||||
|
|
||||||
# Entry types that can be safely removed from the transcript without breaking
|
|
||||||
# the parentUuid conversation tree that ``--resume`` relies on.
|
|
||||||
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
|
|
||||||
# - file-history-snapshot: undo tracking metadata
|
|
||||||
# - queue-operation: internal queue bookkeeping
|
|
||||||
# - summary: session summaries
|
|
||||||
# - pr-link: PR link metadata
|
|
||||||
STRIPPABLE_TYPES = frozenset(
|
|
||||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Workspace storage constants — deterministic path from session_id.
|
|
||||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Progress stripping
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def strip_progress_entries(content: str) -> str:
|
|
||||||
"""Remove progress/metadata entries from a JSONL transcript.
|
|
||||||
|
|
||||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
|
||||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
|
||||||
Typically reduces transcript size by ~30%.
|
|
||||||
"""
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
|
|
||||||
entries: list[dict] = []
|
|
||||||
for line in lines:
|
|
||||||
try:
|
|
||||||
entries.append(json.loads(line))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Keep unparseable lines as-is (safety)
|
|
||||||
entries.append({"_raw": line})
|
|
||||||
|
|
||||||
stripped_uuids: set[str] = set()
|
|
||||||
uuid_to_parent: dict[str, str] = {}
|
|
||||||
kept: list[dict] = []
|
|
||||||
|
|
||||||
for entry in entries:
|
|
||||||
if "_raw" in entry:
|
|
||||||
kept.append(entry)
|
|
||||||
continue
|
|
||||||
uid = entry.get("uuid", "")
|
|
||||||
parent = entry.get("parentUuid", "")
|
|
||||||
entry_type = entry.get("type", "")
|
|
||||||
|
|
||||||
if uid:
|
|
||||||
uuid_to_parent[uid] = parent
|
|
||||||
|
|
||||||
if entry_type in STRIPPABLE_TYPES:
|
|
||||||
if uid:
|
|
||||||
stripped_uuids.add(uid)
|
|
||||||
else:
|
|
||||||
kept.append(entry)
|
|
||||||
|
|
||||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
|
||||||
for entry in kept:
|
|
||||||
if "_raw" in entry:
|
|
||||||
continue
|
|
||||||
parent = entry.get("parentUuid", "")
|
|
||||||
original_parent = parent
|
|
||||||
while parent in stripped_uuids:
|
|
||||||
parent = uuid_to_parent.get(parent, "")
|
|
||||||
if parent != original_parent:
|
|
||||||
entry["parentUuid"] = parent
|
|
||||||
|
|
||||||
result_lines: list[str] = []
|
|
||||||
for entry in kept:
|
|
||||||
if "_raw" in entry:
|
|
||||||
result_lines.append(entry["_raw"])
|
|
||||||
else:
|
|
||||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
|
||||||
|
|
||||||
return "\n".join(result_lines) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def read_transcript_file(transcript_path: str) -> str | None:
|
|
||||||
"""Read a JSONL transcript file from disk.
|
|
||||||
|
|
||||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
|
||||||
or only contains metadata (≤2 lines with no conversation messages).
|
|
||||||
"""
|
|
||||||
if not transcript_path or not os.path.isfile(transcript_path):
|
|
||||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(transcript_path) as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
logger.debug(f"[Transcript] Empty file: {transcript_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
if len(lines) < 2:
|
|
||||||
# Metadata-only files have 1 line (single queue-operation or snapshot).
|
|
||||||
logger.debug(
|
|
||||||
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Quick structural validation — parse first and last lines.
|
|
||||||
json.loads(lines[0])
|
|
||||||
json.loads(lines[-1])
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Read {len(lines)} lines, "
|
|
||||||
f"{len(content)} bytes from {transcript_path}"
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
|
||||||
"""Sanitize an ID for safe use in file paths.
|
|
||||||
|
|
||||||
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
|
|
||||||
everything else and truncate to *max_len* so the result cannot introduce
|
|
||||||
path separators or other special characters.
|
|
||||||
"""
|
|
||||||
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
|
|
||||||
return cleaned or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
|
||||||
|
|
||||||
|
|
||||||
def write_transcript_to_tempfile(
|
|
||||||
transcript_content: str,
|
|
||||||
session_id: str,
|
|
||||||
cwd: str,
|
|
||||||
) -> str | None:
|
|
||||||
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
|
|
||||||
|
|
||||||
The file lives in the session working directory so it is cleaned up
|
|
||||||
automatically when the session ends.
|
|
||||||
|
|
||||||
Returns the absolute path to the file, or ``None`` on failure.
|
|
||||||
"""
|
|
||||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
|
||||||
real_cwd = os.path.realpath(cwd)
|
|
||||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
|
||||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.makedirs(real_cwd, exist_ok=True)
|
|
||||||
safe_id = _sanitize_id(session_id, max_len=8)
|
|
||||||
jsonl_path = os.path.realpath(
|
|
||||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
|
||||||
)
|
|
||||||
if not jsonl_path.startswith(real_cwd):
|
|
||||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
with open(jsonl_path, "w") as f:
|
|
||||||
f.write(transcript_content)
|
|
||||||
|
|
||||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
|
||||||
return jsonl_path
|
|
||||||
|
|
||||||
except OSError as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def validate_transcript(content: str | None) -> bool:
|
|
||||||
"""Check that a transcript has actual conversation messages.
|
|
||||||
|
|
||||||
A valid transcript for resume needs at least one user message and one
|
|
||||||
assistant message (not just queue-operation / file-history-snapshot
|
|
||||||
metadata).
|
|
||||||
"""
|
|
||||||
if not content or not content.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
if len(lines) < 2:
|
|
||||||
return False
|
|
||||||
|
|
||||||
has_user = False
|
|
||||||
has_assistant = False
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
try:
|
|
||||||
entry = json.loads(line)
|
|
||||||
msg_type = entry.get("type")
|
|
||||||
if msg_type == "user":
|
|
||||||
has_user = True
|
|
||||||
elif msg_type == "assistant":
|
|
||||||
has_assistant = True
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return has_user and has_assistant
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Bucket storage (GCS / local via WorkspaceStorageBackend)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
|
||||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
|
||||||
|
|
||||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
|
||||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
TRANSCRIPT_STORAGE_PREFIX,
|
|
||||||
_sanitize_id(user_id),
|
|
||||||
f"{_sanitize_id(session_id)}.jsonl",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
|
||||||
"""Build the full storage path string that ``retrieve()`` expects.
|
|
||||||
|
|
||||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
|
||||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
|
||||||
arguments we can reconstruct the same path for download/delete without
|
|
||||||
having stored the return value.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
|
||||||
|
|
||||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
|
||||||
|
|
||||||
if isinstance(backend, GCSWorkspaceStorage):
|
|
||||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
|
||||||
return f"gcs://{backend.bucket_name}/{blob}"
|
|
||||||
else:
|
|
||||||
# LocalWorkspaceStorage returns local://{relative_path}
|
|
||||||
return f"local://{wid}/{fid}/{fname}"
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
|
|
||||||
"""Strip progress entries and upload transcript to bucket storage.
|
|
||||||
|
|
||||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
|
||||||
what is already stored. Since JSONL is append-only, the latest transcript
|
|
||||||
is always the longest. This prevents a slow/stale background task from
|
|
||||||
clobbering a newer upload from a concurrent turn.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
stripped = strip_progress_entries(content)
|
|
||||||
if not validate_transcript(stripped):
|
|
||||||
logger.warning(
|
|
||||||
f"[Transcript] Skipping upload — stripped content is not a valid "
|
|
||||||
f"transcript for session {session_id}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
|
||||||
encoded = stripped.encode("utf-8")
|
|
||||||
new_size = len(encoded)
|
|
||||||
|
|
||||||
# Check existing transcript size to avoid overwriting newer with older
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
try:
|
|
||||||
existing = await storage.retrieve(path)
|
|
||||||
if len(existing) >= new_size:
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Skipping upload — existing transcript "
|
|
||||||
f"({len(existing)}B) >= new ({new_size}B) for session "
|
|
||||||
f"{session_id}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except (FileNotFoundError, Exception):
|
|
||||||
pass # No existing transcript or retrieval error — proceed with upload
|
|
||||||
|
|
||||||
await storage.store(
|
|
||||||
workspace_id=wid,
|
|
||||||
file_id=fid,
|
|
||||||
filename=fname,
|
|
||||||
content=encoded,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Uploaded {new_size} bytes "
|
|
||||||
f"(stripped from {len(content)}) for session {session_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def download_transcript(user_id: str, session_id: str) -> str | None:
|
|
||||||
"""Download transcript from bucket storage.
|
|
||||||
|
|
||||||
Returns the JSONL content string, or ``None`` if not found.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await storage.retrieve(path)
|
|
||||||
content = data.decode("utf-8")
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
|
||||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await storage.delete(path)
|
|
||||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
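Taken together, the helpers above form a small round-trip: after a turn the Stop hook's JSONL is stripped and uploaded, and before the next turn the transcript is downloaded, validated, and written back into the session workspace so the CLI can be started with --resume. A rough sketch of that flow using only the functions defined in this module; the hook wiring, IDs, and cwd values are placeholder assumptions:

    async def persist_after_turn(user_id: str, session_id: str, transcript_path: str) -> None:
        content = read_transcript_file(transcript_path)  # CLI's JSONL captured by the Stop hook
        if content:
            await upload_transcript(user_id, session_id, content)  # strips bloat, then stores

    async def resume_file_for_next_turn(user_id: str, session_id: str, cwd: str) -> str | None:
        content = await download_transcript(user_id, session_id)
        if not validate_transcript(content):
            return None  # nothing usable to resume from
        return write_transcript_to_tempfile(content, session_id, cwd)  # path to pass via --resume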
@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
     return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
 
 
-async def _build_system_prompt(
-    user_id: str | None, has_conversation_history: bool = False
-) -> tuple[str, Any]:
+async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
     """Build the full system prompt including business understanding if available.
 
     Args:
-        user_id: The user ID for fetching business understanding.
-        has_conversation_history: Whether there's existing conversation history.
-            If True, we don't tell the model to greet/introduce (since they're
-            already in a conversation).
+        user_id: The user ID for fetching business understanding
+            If "default" and this is the user's first session, will use "onboarding" instead.
 
     Returns:
         Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +266,6 @@ async def _build_system_prompt(
 
     if understanding:
         context = format_understanding_for_prompt(understanding)
-    elif has_conversation_history:
-        context = "No prior understanding saved yet. Continue the existing conversation naturally."
     else:
         context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
 
@@ -380,6 +374,7 @@ async def stream_chat_completion(
 
     Raises:
         NotFoundError: If session_id is invalid
+        ValueError: If max_context_messages is exceeded
 
     """
     completion_start = time.monotonic()
@@ -464,9 +459,8 @@ async def stream_chat_completion(
 
     # Generate title for new sessions on first user message (non-blocking)
     # Check: is_user_message, no title yet, and this is the first user message
-    user_messages = [m for m in session.messages if m.role == "user"]
-    first_user_msg = message or (user_messages[0].content if user_messages else None)
-    if is_user_message and first_user_msg and not session.title:
+    if is_user_message and message and not session.title:
+        user_messages = [m for m in session.messages if m.role == "user"]
         if len(user_messages) == 1:
             # First user message - generate title in background
             import asyncio
@@ -474,7 +468,7 @@ async def stream_chat_completion(
             # Capture only the values we need (not the session object) to avoid
             # stale data issues when the main flow modifies the session
             captured_session_id = session_id
-            captured_message = first_user_msg
+            captured_message = message
             captured_user_id = user_id
 
             async def _update_title():
@@ -1243,7 +1237,7 @@ async def _stream_chat_chunks(
 
     total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
     logger.info(
-        f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
+        f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
         f"session={session.session_id}, user={session.user_id}",
         extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
     )
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from os import getenv
 
@@ -12,8 +11,6 @@ from .response_model import (
     StreamTextDelta,
     StreamToolOutputAvailable,
 )
-from .sdk import service as sdk_service
-from .sdk.transcript import download_transcript
 
 logger = logging.getLogger(__name__)
 
@@ -83,96 +80,3 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
|||||||
session = await get_chat_session(session.session_id)
|
session = await get_chat_session(session.session_id)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
assert session.usage, "Usage is empty"
|
assert session.usage, "Usage is empty"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
|
||||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
|
||||||
|
|
||||||
Turn 1: Send a message containing a unique keyword.
|
|
||||||
Turn 2: Ask the model to recall that keyword — proving the transcript was
|
|
||||||
persisted and restored via --resume.
|
|
||||||
"""
|
|
||||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
|
||||||
|
|
||||||
from .config import ChatConfig
|
|
||||||
|
|
||||||
cfg = ChatConfig()
|
|
||||||
if not cfg.claude_agent_use_resume:
|
|
||||||
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# --- Turn 1: send a message with a unique keyword ---
|
|
||||||
keyword = "ZEPHYR42"
|
|
||||||
turn1_msg = (
|
|
||||||
f"Please remember this special keyword: {keyword}. "
|
|
||||||
"Just confirm you've noted it, keep your response brief."
|
|
||||||
)
|
|
||||||
turn1_text = ""
|
|
||||||
turn1_errors: list[str] = []
|
|
||||||
turn1_ended = False
|
|
||||||
|
|
||||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
|
||||||
session.session_id,
|
|
||||||
turn1_msg,
|
|
||||||
user_id=test_user_id,
|
|
||||||
):
|
|
||||||
if isinstance(chunk, StreamTextDelta):
|
|
||||||
turn1_text += chunk.delta
|
|
||||||
elif isinstance(chunk, StreamError):
|
|
||||||
turn1_errors.append(chunk.errorText)
|
|
||||||
elif isinstance(chunk, StreamFinish):
|
|
||||||
turn1_ended = True
|
|
||||||
|
|
||||||
assert turn1_ended, "Turn 1 did not finish"
|
|
||||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
|
||||||
assert turn1_text, "Turn 1 produced no text"
|
|
||||||
|
|
||||||
# Wait for background upload task to complete (retry up to 5s)
|
|
||||||
transcript = None
|
|
||||||
for _ in range(10):
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
transcript = await download_transcript(test_user_id, session.session_id)
|
|
||||||
if transcript:
|
|
||||||
break
|
|
||||||
assert transcript, (
|
|
||||||
"Transcript was not uploaded to bucket after turn 1 — "
|
|
||||||
"Stop hook may not have fired or transcript was too small"
|
|
||||||
)
|
|
||||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
|
|
||||||
|
|
||||||
# Reload session for turn 2
|
|
||||||
session = await get_chat_session(session.session_id, test_user_id)
|
|
||||||
assert session, "Session not found after turn 1"
|
|
||||||
|
|
||||||
# --- Turn 2: ask model to recall the keyword ---
|
|
||||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
|
||||||
turn2_text = ""
|
|
||||||
turn2_errors: list[str] = []
|
|
||||||
turn2_ended = False
|
|
||||||
|
|
||||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
|
||||||
session.session_id,
|
|
||||||
turn2_msg,
|
|
||||||
user_id=test_user_id,
|
|
||||||
session=session,
|
|
||||||
):
|
|
||||||
if isinstance(chunk, StreamTextDelta):
|
|
||||||
turn2_text += chunk.delta
|
|
||||||
elif isinstance(chunk, StreamError):
|
|
||||||
turn2_errors.append(chunk.errorText)
|
|
||||||
elif isinstance(chunk, StreamFinish):
|
|
||||||
turn2_ended = True
|
|
||||||
|
|
||||||
assert turn2_ended, "Turn 2 did not finish"
|
|
||||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
|
||||||
assert turn2_text, "Turn 2 produced no text"
|
|
||||||
assert keyword in turn2_text, (
|
|
||||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
|
||||||
f"Response: {turn2_text[:200]}"
|
|
||||||
)
|
|
||||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
|
||||||
|
|||||||
@@ -814,28 +814,6 @@ async def get_active_task_for_session(
         if task_user_id and user_id != task_user_id:
             continue
 
-        # Auto-expire stale tasks that exceeded stream_timeout
-        created_at_str = meta.get("created_at", "")
-        if created_at_str:
-            try:
-                created_at = datetime.fromisoformat(created_at_str)
-                age_seconds = (
-                    datetime.now(timezone.utc) - created_at
-                ).total_seconds()
-                if age_seconds > config.stream_timeout:
-                    logger.warning(
-                        f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
-                        f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
-                    )
-                    await mark_task_completed(task_id, "failed")
-                    continue
-            except (ValueError, TypeError):
-                pass
-
-        logger.info(
-            f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
-        )
-
         # Get the last message ID from Redis Stream
         stream_key = _get_task_stream_key(task_id)
         last_id = "0-0"
@@ -9,8 +9,6 @@ from backend.api.features.chat.tracking import track_tool_called
 from .add_understanding import AddUnderstandingTool
 from .agent_output import AgentOutputTool
 from .base import BaseTool
-from .bash_exec import BashExecTool
-from .check_operation_status import CheckOperationStatusTool
 from .create_agent import CreateAgentTool
 from .customize_agent import CustomizeAgentTool
 from .edit_agent import EditAgentTool
@@ -21,7 +19,6 @@ from .get_doc_page import GetDocPageTool
 from .run_agent import RunAgentTool
 from .run_block import RunBlockTool
 from .search_docs import SearchDocsTool
-from .web_fetch import WebFetchTool
 from .workspace_files import (
     DeleteWorkspaceFileTool,
     ListWorkspaceFilesTool,
@@ -46,14 +43,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
     "run_agent": RunAgentTool(),
     "run_block": RunBlockTool(),
     "view_agent_output": AgentOutputTool(),
-    "check_operation_status": CheckOperationStatusTool(),
     "search_docs": SearchDocsTool(),
     "get_doc_page": GetDocPageTool(),
-    # Web fetch for safe URL retrieval
-    "web_fetch": WebFetchTool(),
-    # Sandboxed code execution (bubblewrap)
-    "bash_exec": BashExecTool(),
-    # Persistent workspace tools (cloud storage, survives across sessions)
+    # Workspace tools for CoPilot file operations
     "list_workspace_files": ListWorkspaceFilesTool(),
     "read_workspace_file": ReadWorkspaceFileTool(),
     "write_workspace_file": WriteWorkspaceFileTool(),
@@ -1,131 +0,0 @@
|
|||||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
|
||||||
|
|
||||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
|
||||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
|
||||||
read-only, writable workspace only, clean env, no network.
|
|
||||||
|
|
||||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
|
||||||
available (e.g. macOS development).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
from backend.api.features.chat.tools.models import (
|
|
||||||
BashExecResponse,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.sandbox import (
|
|
||||||
get_workspace_dir,
|
|
||||||
has_full_sandbox,
|
|
||||||
run_sandboxed,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BashExecTool(BaseTool):
|
|
||||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "bash_exec"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
if not has_full_sandbox():
|
|
||||||
return (
|
|
||||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
|
||||||
"available on this platform. Do not call this tool."
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
|
||||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
|
||||||
"functions, etc.). "
|
|
||||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
|
||||||
"tools — files created by either are accessible to both. "
|
|
||||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
|
||||||
"visible read-only, the per-session workspace is the only writable "
|
|
||||||
"path, environment variables are wiped (no secrets), all network "
|
|
||||||
"access is blocked at the kernel level, and resource limits are "
|
|
||||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
|
||||||
"Application code, configs, and other directories are NOT accessible. "
|
|
||||||
"To fetch web content, use the web_fetch tool instead. "
|
|
||||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
|
||||||
"Returns stdout and stderr. "
|
|
||||||
"Useful for file manipulation, data processing with Unix tools "
|
|
||||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Bash command or script to execute.",
|
|
||||||
},
|
|
||||||
"timeout": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": (
|
|
||||||
"Max execution time in seconds (default 30, max 120)."
|
|
||||||
),
|
|
||||||
"default": 30,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["command"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
if not has_full_sandbox():
|
|
||||||
return ErrorResponse(
|
|
||||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
|
||||||
error="sandbox_unavailable",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
command: str = (kwargs.get("command") or "").strip()
|
|
||||||
timeout: int = kwargs.get("timeout", 30)
|
|
||||||
|
|
||||||
if not command:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="No command provided.",
|
|
||||||
error="empty_command",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
workspace = get_workspace_dir(session_id or "default")
|
|
||||||
|
|
||||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
|
||||||
command=["bash", "-c", command],
|
|
||||||
cwd=workspace,
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
return BashExecResponse(
|
|
||||||
message=(
|
|
||||||
"Execution timed out"
|
|
||||||
if timed_out
|
|
||||||
else f"Command executed (exit {exit_code})"
|
|
||||||
),
|
|
||||||
stdout=stdout,
|
|
||||||
stderr=stderr,
|
|
||||||
exit_code=exit_code,
|
|
||||||
timed_out=timed_out,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
from backend.api.features.chat.tools.models import (
|
|
||||||
ErrorResponse,
|
|
||||||
ResponseType,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OperationStatusResponse(ToolResponseBase):
|
|
||||||
"""Response for check_operation_status tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STATUS
|
|
||||||
task_id: str
|
|
||||||
operation_id: str
|
|
||||||
status: str # "running", "completed", "failed"
|
|
||||||
tool_name: str | None = None
|
|
||||||
message: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class CheckOperationStatusTool(BaseTool):
|
|
||||||
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
|
||||||
|
|
||||||
The CoPilot uses this tool to report back to the user whether an
|
|
||||||
operation that was started earlier has completed, failed, or is still
|
|
||||||
running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "check_operation_status"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Check the current status of a long-running operation such as "
|
|
||||||
"create_agent or edit_agent. Accepts either an operation_id or "
|
|
||||||
"task_id from a previous operation_started response. "
|
|
||||||
"Returns the current status: running, completed, or failed."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"operation_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The operation_id from an operation_started response."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"task_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The task_id from an operation_started response. "
|
|
||||||
"Used as fallback if operation_id is not provided."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
from backend.api.features.chat import stream_registry
|
|
||||||
|
|
||||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
|
||||||
task_id = (kwargs.get("task_id") or "").strip()
|
|
||||||
|
|
||||||
if not operation_id and not task_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide an operation_id or task_id.",
|
|
||||||
error="missing_parameter",
|
|
||||||
)
|
|
||||||
|
|
||||||
task = None
|
|
||||||
if operation_id:
|
|
||||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
|
||||||
if task is None and task_id:
|
|
||||||
task = await stream_registry.get_task(task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
# Task not in Redis — it may have already expired (TTL).
|
|
||||||
# Check conversation history for the result instead.
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Operation not found — it may have already completed and "
|
|
||||||
"expired from the status tracker. Check the conversation "
|
|
||||||
"history for the result."
|
|
||||||
),
|
|
||||||
error="not_found",
|
|
||||||
)
|
|
||||||
|
|
||||||
status_messages = {
|
|
||||||
"running": (
|
|
||||||
f"The {task.tool_name or 'operation'} is still running. "
|
|
||||||
"Please wait for it to complete."
|
|
||||||
),
|
|
||||||
"completed": (
|
|
||||||
f"The {task.tool_name or 'operation'} has completed successfully."
|
|
||||||
),
|
|
||||||
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
|
||||||
}
|
|
||||||
|
|
||||||
return OperationStatusResponse(
|
|
||||||
task_id=task.task_id,
|
|
||||||
operation_id=task.operation_id,
|
|
||||||
status=task.status,
|
|
||||||
tool_name=task.tool_name,
|
|
||||||
message=status_messages.get(task.status, f"Status: {task.status}"),
|
|
||||||
)
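This status tool pairs with the long-running tool flow: a slow operation immediately returns an operation_started payload carrying operation_id and task_id, and the model later calls check_operation_status until the Redis-tracked task resolves. A rough polling sketch; the literal ID and the direct _execute call are illustrative only (the tests in this repo invoke tools the same way):

    async def poll_operation_status(session) -> str:
        tool = CheckOperationStatusTool()
        result = await tool._execute(
            user_id=None,
            session=session,
            operation_id="00000000-0000-0000-0000-000000000000",  # placeholder, not a real ID
        )
        # OperationStatusResponse.status is "running", "completed", or "failed";
        # an ErrorResponse here means the task already expired from Redis.
        return getattr(result, "status", "unknown")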
@@ -7,6 +7,7 @@ from backend.api.features.chat.model import ChatSession
 from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
 from backend.api.features.chat.tools.models import (
     BlockInfoSummary,
+    BlockInputFieldInfo,
     BlockListResponse,
     ErrorResponse,
     NoResultsResponse,
@@ -54,8 +55,7 @@ class FindBlockTool(BaseTool):
             "Blocks are reusable components that perform specific tasks like "
             "sending emails, making API calls, processing text, etc. "
             "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
-            "The response includes each block's id, name, and description. "
-            "Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
+            "The response includes each block's id, required_inputs, and input_schema."
         )
 
     @property
@@ -124,7 +124,7 @@ class FindBlockTool(BaseTool):
                 session_id=session_id,
             )
 
-        # Enrich results with block information
+        # Enrich results with full block information
         blocks: list[BlockInfoSummary] = []
         for result in results:
             block_id = result["content_id"]
@@ -141,12 +141,65 @@ class FindBlockTool(BaseTool):
             ):
                 continue
 
+            # Get input/output schemas
+            input_schema = {}
+            output_schema = {}
+            try:
+                input_schema = block.input_schema.jsonschema()
+            except Exception as e:
+                logger.debug(
+                    "Failed to generate input schema for block %s: %s",
+                    block_id,
+                    e,
+                )
+            try:
+                output_schema = block.output_schema.jsonschema()
+            except Exception as e:
+                logger.debug(
+                    "Failed to generate output schema for block %s: %s",
+                    block_id,
+                    e,
+                )
+
+            # Get categories from block instance
+            categories = []
+            if hasattr(block, "categories") and block.categories:
+                categories = [cat.value for cat in block.categories]
+
+            # Extract required inputs for easier use
+            required_inputs: list[BlockInputFieldInfo] = []
+            if input_schema:
+                properties = input_schema.get("properties", {})
+                required_fields = set(input_schema.get("required", []))
+                # Get credential field names to exclude from required inputs
+                credentials_fields = set(
+                    block.input_schema.get_credentials_fields().keys()
+                )
+
+                for field_name, field_schema in properties.items():
+                    # Skip credential fields - they're handled separately
+                    if field_name in credentials_fields:
+                        continue
+
+                    required_inputs.append(
+                        BlockInputFieldInfo(
+                            name=field_name,
+                            type=field_schema.get("type", "string"),
+                            description=field_schema.get("description", ""),
+                            required=field_name in required_fields,
+                            default=field_schema.get("default"),
+                        )
+                    )
+
             blocks.append(
                 BlockInfoSummary(
                     id=block_id,
                     name=block.name,
                     description=block.description or "",
-                    categories=[c.value for c in block.categories],
+                    categories=categories,
+                    input_schema=input_schema,
+                    output_schema=output_schema,
+                    required_inputs=required_inputs,
                 )
             )
 
@@ -175,7 +228,8 @@ class FindBlockTool(BaseTool):
         return BlockListResponse(
             message=(
                 f"Found {len(blocks)} block(s) matching '{query}'. "
-                "To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs."
+                "To execute a block, use run_block with the block's 'id' field "
+                "and provide 'input_data' matching the block's input_schema."
             ),
             blocks=blocks,
             count=len(blocks),
@@ -18,13 +18,7 @@ _TEST_USER_ID = "test-user-find-block"
|
|||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
def make_mock_block(
|
||||||
block_id: str,
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
name: str,
|
|
||||||
block_type: BlockType,
|
|
||||||
disabled: bool = False,
|
|
||||||
input_schema: dict | None = None,
|
|
||||||
output_schema: dict | None = None,
|
|
||||||
credentials_fields: dict | None = None,
|
|
||||||
):
|
):
|
||||||
"""Create a mock block for testing."""
|
"""Create a mock block for testing."""
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
@@ -34,13 +28,10 @@ def make_mock_block(
|
|||||||
mock.block_type = block_type
|
mock.block_type = block_type
|
||||||
mock.disabled = disabled
|
mock.disabled = disabled
|
||||||
mock.input_schema = MagicMock()
|
mock.input_schema = MagicMock()
|
||||||
mock.input_schema.jsonschema.return_value = input_schema or {
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
"properties": {},
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
mock.input_schema.get_credentials_fields.return_value = credentials_fields or {}
|
|
||||||
mock.output_schema = MagicMock()
|
mock.output_schema = MagicMock()
|
||||||
mock.output_schema.jsonschema.return_value = output_schema or {}
|
mock.output_schema.jsonschema.return_value = {}
|
||||||
mock.categories = []
|
mock.categories = []
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
@@ -146,241 +137,3 @@ class TestFindBlockFiltering:
|
|||||||
assert isinstance(response, BlockListResponse)
|
assert isinstance(response, BlockListResponse)
|
||||||
assert len(response.blocks) == 1
|
assert len(response.blocks) == 1
|
||||||
assert response.blocks[0].id == "normal-block-id"
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_response_size_average_chars_per_block(self):
|
|
||||||
"""Measure average chars per block in the serialized response."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
# Realistic block definitions modeled after real blocks
|
|
||||||
block_defs = [
|
|
||||||
{
|
|
||||||
"id": "http-block-id",
|
|
||||||
"name": "Send Web Request",
|
|
||||||
"input_schema": {
|
|
||||||
"properties": {
|
|
||||||
"url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The URL to send the request to",
|
|
||||||
},
|
|
||||||
"method": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The HTTP method to use",
|
|
||||||
},
|
|
||||||
"headers": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Headers to include in the request",
|
|
||||||
},
|
|
||||||
"json_format": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "If true, send the body as JSON",
|
|
||||||
},
|
|
||||||
"body": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Form/JSON body payload",
|
|
||||||
},
|
|
||||||
"credentials": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "HTTP credentials",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["url", "method"],
|
|
||||||
},
|
|
||||||
"output_schema": {
|
|
||||||
"properties": {
|
|
||||||
"response": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "The response from the server",
|
|
||||||
},
|
|
||||||
"client_error": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Errors on 4xx status codes",
|
|
||||||
},
|
|
||||||
"server_error": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Errors on 5xx status codes",
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Errors for all other exceptions",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"credentials_fields": {"credentials": True},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "email-block-id",
|
|
||||||
"name": "Send Email",
|
|
||||||
"input_schema": {
|
|
||||||
"properties": {
|
|
||||||
"to_email": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Recipient email address",
|
|
||||||
},
|
|
||||||
"subject": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Subject of the email",
|
|
||||||
},
|
|
||||||
"body": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Body of the email",
|
|
||||||
},
|
|
||||||
"config": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "SMTP Config",
|
|
||||||
},
|
|
||||||
"credentials": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "SMTP credentials",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["to_email", "subject", "body", "credentials"],
|
|
||||||
},
|
|
||||||
"output_schema": {
|
|
||||||
"properties": {
|
|
||||||
"status": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Status of the email sending operation",
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Error message if sending failed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"credentials_fields": {"credentials": True},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "claude-code-block-id",
|
|
||||||
"name": "Claude Code",
|
|
||||||
"input_schema": {
|
|
||||||
"properties": {
|
|
||||||
"e2b_credentials": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "API key for E2B platform",
|
|
||||||
},
|
|
||||||
"anthropic_credentials": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "API key for Anthropic",
|
|
||||||
},
|
|
||||||
"prompt": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Task or instruction for Claude Code",
|
|
||||||
},
|
|
||||||
"timeout": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Sandbox timeout in seconds",
|
|
||||||
},
|
|
||||||
"setup_commands": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "Shell commands to run before execution",
|
|
||||||
},
|
|
||||||
"working_directory": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Working directory for Claude Code",
|
|
||||||
},
|
|
||||||
"session_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Session ID to resume a conversation",
|
|
||||||
},
|
|
||||||
"sandbox_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Sandbox ID to reconnect to",
|
|
||||||
},
|
|
||||||
"conversation_history": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Previous conversation history",
|
|
||||||
},
|
|
||||||
"dispose_sandbox": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Whether to dispose sandbox after execution",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [
|
|
||||||
"e2b_credentials",
|
|
||||||
"anthropic_credentials",
|
|
||||||
"prompt",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"output_schema": {
|
|
||||||
"properties": {
|
|
||||||
"response": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Output from Claude Code execution",
|
|
||||||
},
|
|
||||||
"files": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "Files created/modified by Claude Code",
|
|
||||||
},
|
|
||||||
"conversation_history": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Full conversation history",
|
|
||||||
},
|
|
||||||
"session_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Session ID for this conversation",
|
|
||||||
},
|
|
||||||
"sandbox_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "ID of the sandbox instance",
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Error message if execution failed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"credentials_fields": {
|
|
||||||
"e2b_credentials": True,
|
|
||||||
"anthropic_credentials": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
search_results = [
|
|
||||||
{"content_id": d["id"], "score": 0.9 - i * 0.1}
|
|
||||||
for i, d in enumerate(block_defs)
|
|
||||||
]
|
|
||||||
mock_blocks = {
|
|
||||||
d["id"]: make_mock_block(
|
|
||||||
block_id=d["id"],
|
|
||||||
name=d["name"],
|
|
||||||
block_type=BlockType.STANDARD,
|
|
||||||
input_schema=d["input_schema"],
|
|
||||||
output_schema=d["output_schema"],
|
|
||||||
credentials_fields=d["credentials_fields"],
|
|
||||||
)
|
|
||||||
for d in block_defs
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, len(search_results)),
|
|
||||||
), patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=lambda bid: mock_blocks.get(bid),
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="test"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert response.count == len(block_defs)
|
|
||||||
|
|
||||||
total_chars = len(response.model_dump_json())
|
|
||||||
avg_chars = total_chars // response.count
|
|
||||||
|
|
||||||
# Print for visibility in test output
|
|
||||||
print(f"\nTotal response size: {total_chars} chars")
|
|
||||||
print(f"Number of blocks: {response.count}")
|
|
||||||
print(f"Average chars per block: {avg_chars}")
|
|
||||||
|
|
||||||
# The old response was ~90K for 10 blocks (~9K per block).
|
|
||||||
# Previous optimization reduced it to ~1.5K per block (no raw JSON schemas).
|
|
||||||
# Now with only id/name/description, we expect ~300 chars per block.
|
|
||||||
assert avg_chars < 500, (
|
|
||||||
f"Average chars per block ({avg_chars}) exceeds 500. "
|
|
||||||
f"Total response: {total_chars} chars for {response.count} blocks."
|
|
||||||
)
@@ -25,7 +25,6 @@ class ResponseType(str, Enum):
|
|||||||
AGENT_SAVED = "agent_saved"
|
AGENT_SAVED = "agent_saved"
|
||||||
CLARIFICATION_NEEDED = "clarification_needed"
|
CLARIFICATION_NEEDED = "clarification_needed"
|
||||||
BLOCK_LIST = "block_list"
|
BLOCK_LIST = "block_list"
|
||||||
BLOCK_DETAILS = "block_details"
|
|
||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
@@ -41,12 +40,6 @@ class ResponseType(str, Enum):
|
|||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
# Input validation
|
# Input validation
|
||||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||||
# Web fetch
|
|
||||||
WEB_FETCH = "web_fetch"
|
|
||||||
# Code execution
|
|
||||||
BASH_EXEC = "bash_exec"
|
|
||||||
# Operation status check
|
|
||||||
OPERATION_STATUS = "operation_status"
|
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -342,17 +335,11 @@ class BlockInfoSummary(BaseModel):
     name: str
     description: str
     categories: list[str]
-    input_schema: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Full JSON schema for block inputs",
-    )
-    output_schema: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Full JSON schema for block outputs",
-    )
+    input_schema: dict[str, Any]
+    output_schema: dict[str, Any]
     required_inputs: list[BlockInputFieldInfo] = Field(
         default_factory=list,
-        description="List of input fields for this block",
+        description="List of required input fields for this block",
     )

@@ -365,29 +352,10 @@ class BlockListResponse(ToolResponseBase):
     query: str
     usage_hint: str = Field(
         default="To execute a block, call run_block with block_id set to the block's "
-        "'id' field and input_data containing the fields listed in required_inputs."
+        "'id' field and input_data containing the required fields from input_schema."
     )


-class BlockDetails(BaseModel):
-    """Detailed block information."""
-
-    id: str
-    name: str
-    description: str
-    inputs: dict[str, Any] = {}
-    outputs: dict[str, Any] = {}
-    credentials: list[CredentialsMetaInput] = []
-
-
-class BlockDetailsResponse(ToolResponseBase):
-    """Response for block details (first run_block attempt)."""
-
-    type: ResponseType = ResponseType.BLOCK_DETAILS
-    block: BlockDetails
-    user_authenticated: bool = False
-
-
 class BlockOutputResponse(ToolResponseBase):
     """Response for run_block tool."""

@@ -453,24 +421,3 @@ class AsyncProcessingResponse(ToolResponseBase):
     status: str = "accepted"  # Must be "accepted" for detection
     operation_id: str | None = None
     task_id: str | None = None
-
-
-class WebFetchResponse(ToolResponseBase):
-    """Response for web_fetch tool."""
-
-    type: ResponseType = ResponseType.WEB_FETCH
-    url: str
-    status_code: int
-    content_type: str
-    content: str
-    truncated: bool = False
-
-
-class BashExecResponse(ToolResponseBase):
-    """Response for bash_exec tool."""
-
-    type: ResponseType = ResponseType.BASH_EXEC
-    stdout: str
-    stderr: str
-    exit_code: int
-    timed_out: bool = False
@@ -23,11 +23,8 @@ from backend.util.exceptions import BlockError
 from .base import BaseTool
 from .helpers import get_inputs_from_schema
 from .models import (
-    BlockDetails,
-    BlockDetailsResponse,
     BlockOutputResponse,
     ErrorResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -54,8 +51,8 @@ class RunBlockTool(BaseTool):
             "Execute a specific block with the provided input data. "
             "IMPORTANT: You MUST call find_block first to get the block's 'id' - "
             "do NOT guess or make up block IDs. "
-            "On first attempt (without input_data), returns detailed schema showing "
-            "required inputs and outputs. Then call again with proper input_data to execute."
+            "Use the 'id' from find_block results and provide input_data "
+            "matching the block's required_inputs."
         )

     @property
@@ -70,19 +67,11 @@ class RunBlockTool(BaseTool):
                         "NEVER guess this - always get it from find_block first."
                     ),
                 },
-                "block_name": {
-                    "type": "string",
-                    "description": (
-                        "The block's human-readable name from find_block results. "
-                        "Used for display purposes in the UI."
-                    ),
-                },
                 "input_data": {
                     "type": "object",
                     "description": (
-                        "Input values for the block. "
-                        "First call with empty {} to see the block's schema, "
-                        "then call again with proper values to execute."
+                        "Input values for the block. Use the 'required_inputs' field "
+                        "from find_block to see what fields are needed."
                     ),
                 },
             },
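Editor's note: a minimal sketch of the call sequence the updated description and parameters imply — find_block first, then run_block with input_data keyed by the summary's required_inputs. The block id, field names inside the example payloads, and values are illustrative assumptions, not taken from this diff.

# Hypothetical payloads only: the dict shapes mirror BlockInfoSummary /
# run_block parameters from this diff, but the concrete block and values
# are made up for illustration.
find_block_entry = {
    "id": "http-request-block-id",  # always taken from find_block, never guessed
    "name": "HTTP Request",
    "required_inputs": [{"name": "url", "type": "string"}],
}

run_block_arguments = {
    "block_id": find_block_entry["id"],
    "input_data": {"url": "https://example.com"},  # keys follow required_inputs
}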
@@ -167,34 +156,6 @@ class RunBlockTool(BaseTool):
             await self._resolve_block_credentials(user_id, block, input_data)
         )

-        # Get block schemas for details/validation
-        try:
-            input_schema: dict[str, Any] = block.input_schema.jsonschema()
-        except Exception as e:
-            logger.warning(
-                "Failed to generate input schema for block %s: %s",
-                block_id,
-                e,
-            )
-            return ErrorResponse(
-                message=f"Block '{block.name}' has an invalid input schema",
-                error=str(e),
-                session_id=session_id,
-            )
-        try:
-            output_schema: dict[str, Any] = block.output_schema.jsonschema()
-        except Exception as e:
-            logger.warning(
-                "Failed to generate output schema for block %s: %s",
-                block_id,
-                e,
-            )
-            return ErrorResponse(
-                message=f"Block '{block.name}' has an invalid output schema",
-                error=str(e),
-                session_id=session_id,
-            )
-
         if missing_credentials:
             # Return setup requirements response with missing credentials
             credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -227,53 +188,6 @@ class RunBlockTool(BaseTool):
                 graph_version=None,
             )

-        # Check if this is a first attempt (required inputs missing)
-        # Return block details so user can see what inputs are needed
-        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
-        required_keys = set(input_schema.get("required", []))
-        required_non_credential_keys = required_keys - credentials_fields
-        provided_input_keys = set(input_data.keys()) - credentials_fields
-
-        # Check for unknown input fields
-        valid_fields = (
-            set(input_schema.get("properties", {}).keys()) - credentials_fields
-        )
-        unrecognized_fields = provided_input_keys - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Block was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=input_schema,
-            )
-
-        # Show details when not all required non-credential inputs are provided
-        if not (required_non_credential_keys <= provided_input_keys):
-            # Get credentials info for the response
-            credentials_meta = []
-            for field_name, cred_meta in matched_credentials.items():
-                credentials_meta.append(cred_meta)
-
-            return BlockDetailsResponse(
-                message=(
-                    f"Block '{block.name}' details. "
-                    "Provide input_data matching the inputs schema to execute the block."
-                ),
-                session_id=session_id,
-                block=BlockDetails(
-                    id=block_id,
-                    name=block.name,
-                    description=block.description or "",
-                    inputs=input_schema,
-                    outputs=output_schema,
-                    credentials=credentials_meta,
-                ),
-                user_authenticated=True,
-            )
-
         try:
             # Get or create user's workspace for CoPilot file operations
             workspace = await get_or_create_workspace(user_id)
@@ -1,15 +1,10 @@
-"""Tests for block execution guards and input validation in RunBlockTool."""
+"""Tests for block execution guards in RunBlockTool."""

-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock, patch

 import pytest

-from backend.api.features.chat.tools.models import (
-    BlockDetailsResponse,
-    BlockOutputResponse,
-    ErrorResponse,
-    InputValidationErrorResponse,
-)
+from backend.api.features.chat.tools.models import ErrorResponse
 from backend.api.features.chat.tools.run_block import RunBlockTool
 from backend.blocks._base import BlockType

@@ -33,39 +28,6 @@ def make_mock_block(
     return mock


-def make_mock_block_with_schema(
-    block_id: str,
-    name: str,
-    input_properties: dict,
-    required_fields: list[str],
-    output_properties: dict | None = None,
-):
-    """Create a mock block with a defined input/output schema for validation tests."""
-    mock = MagicMock()
-    mock.id = block_id
-    mock.name = name
-    mock.block_type = BlockType.STANDARD
-    mock.disabled = False
-    mock.description = f"Test block: {name}"
-
-    input_schema = {
-        "properties": input_properties,
-        "required": required_fields,
-    }
-    mock.input_schema = MagicMock()
-    mock.input_schema.jsonschema.return_value = input_schema
-    mock.input_schema.get_credentials_fields_info.return_value = {}
-    mock.input_schema.get_credentials_fields.return_value = {}
-
-    output_schema = {
-        "properties": output_properties or {"result": {"type": "string"}},
-    }
-    mock.output_schema = MagicMock()
-    mock.output_schema.jsonschema.return_value = output_schema
-
-    return mock
-
-
 class TestRunBlockFiltering:
     """Tests for block execution guards in RunBlockTool."""

@@ -142,221 +104,3 @@ class TestRunBlockFiltering:
         # (may be other errors like missing credentials, but not the exclusion guard)
         if isinstance(response, ErrorResponse):
             assert "cannot be run directly in CoPilot" not in response.message
-
-
-class TestRunBlockInputValidation:
-    """Tests for input field validation in RunBlockTool.
-
-    run_block rejects unknown input field names with InputValidationErrorResponse,
-    preventing silent failures where incorrect keys would be ignored and the block
-    would execute with default values instead of the caller's intended values.
-    """
-
-    @pytest.mark.asyncio(loop_scope="session")
-    async def test_unknown_input_fields_are_rejected(self):
-        """run_block rejects unknown input fields instead of silently ignoring them.
-
-        Scenario: The AI Text Generator block has a field called 'model' (for LLM model
-        selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
-        instead. The block should reject the request and return the valid schema.
-        """
-        session = make_session(user_id=_TEST_USER_ID)
-
-        mock_block = make_mock_block_with_schema(
-            block_id="ai-text-gen-id",
-            name="AI Text Generator",
-            input_properties={
-                "prompt": {"type": "string", "description": "The prompt to send"},
-                "model": {
-                    "type": "string",
-                    "description": "The LLM model to use",
-                    "default": "gpt-4o-mini",
-                },
-                "sys_prompt": {
-                    "type": "string",
-                    "description": "System prompt",
-                    "default": "",
-                },
-            },
-            required_fields=["prompt"],
-            output_properties={"response": {"type": "string"}},
-        )
-
-        with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
-            return_value=mock_block,
-        ):
-            tool = RunBlockTool()
-
-            # Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="ai-text-gen-id",
-                input_data={
-                    "prompt": "Write a haiku about coding",
-                    "LLM_Model": "claude-opus-4-6",  # WRONG KEY - should be 'model'
-                },
-            )
-
-            assert isinstance(response, InputValidationErrorResponse)
-            assert "LLM_Model" in response.unrecognized_fields
-            assert "Block was not executed" in response.message
-            assert "inputs" in response.model_dump()  # valid schema included
-
-    @pytest.mark.asyncio(loop_scope="session")
-    async def test_multiple_wrong_keys_are_all_reported(self):
-        """All unrecognized field names are reported in a single error response."""
-        session = make_session(user_id=_TEST_USER_ID)
-
-        mock_block = make_mock_block_with_schema(
-            block_id="ai-text-gen-id",
-            name="AI Text Generator",
-            input_properties={
-                "prompt": {"type": "string"},
-                "model": {"type": "string", "default": "gpt-4o-mini"},
-                "sys_prompt": {"type": "string", "default": ""},
-                "retry": {"type": "integer", "default": 3},
-            },
-            required_fields=["prompt"],
-        )
-
-        with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
-            return_value=mock_block,
-        ):
-            tool = RunBlockTool()
-
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="ai-text-gen-id",
-                input_data={
-                    "prompt": "Hello",  # correct
-                    "llm_model": "claude-opus-4-6",  # WRONG - should be 'model'
-                    "system_prompt": "Be helpful",  # WRONG - should be 'sys_prompt'
-                    "retries": 5,  # WRONG - should be 'retry'
-                },
-            )
-
-            assert isinstance(response, InputValidationErrorResponse)
-            assert set(response.unrecognized_fields) == {
-                "llm_model",
-                "system_prompt",
-                "retries",
-            }
-            assert "Block was not executed" in response.message
-
-    @pytest.mark.asyncio(loop_scope="session")
-    async def test_unknown_fields_rejected_even_with_missing_required(self):
-        """Unknown fields are caught before the missing-required-fields check."""
-        session = make_session(user_id=_TEST_USER_ID)
-
-        mock_block = make_mock_block_with_schema(
-            block_id="ai-text-gen-id",
-            name="AI Text Generator",
-            input_properties={
-                "prompt": {"type": "string"},
-                "model": {"type": "string", "default": "gpt-4o-mini"},
-            },
-            required_fields=["prompt"],
-        )
-
-        with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
-            return_value=mock_block,
-        ):
-            tool = RunBlockTool()
-
-            # 'prompt' is missing AND 'LLM_Model' is an unknown field
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="ai-text-gen-id",
-                input_data={
-                    "LLM_Model": "claude-opus-4-6",  # wrong key, and 'prompt' is missing
-                },
-            )
-
-            # Unknown fields are caught first
-            assert isinstance(response, InputValidationErrorResponse)
-            assert "LLM_Model" in response.unrecognized_fields
-
-    @pytest.mark.asyncio(loop_scope="session")
-    async def test_correct_inputs_still_execute(self):
-        """Correct input field names pass validation and the block executes."""
-        session = make_session(user_id=_TEST_USER_ID)
-
-        mock_block = make_mock_block_with_schema(
-            block_id="ai-text-gen-id",
-            name="AI Text Generator",
-            input_properties={
-                "prompt": {"type": "string"},
-                "model": {"type": "string", "default": "gpt-4o-mini"},
-            },
-            required_fields=["prompt"],
-        )
-
-        async def mock_execute(input_data, **kwargs):
-            yield "response", "Generated text"
-
-        mock_block.execute = mock_execute
-
-        with (
-            patch(
-                "backend.api.features.chat.tools.run_block.get_block",
-                return_value=mock_block,
-            ),
-            patch(
-                "backend.api.features.chat.tools.run_block.get_or_create_workspace",
-                new_callable=AsyncMock,
-                return_value=MagicMock(id="test-workspace-id"),
-            ),
-        ):
-            tool = RunBlockTool()
-
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="ai-text-gen-id",
-                input_data={
-                    "prompt": "Write a haiku",
-                    "model": "gpt-4o-mini",  # correct field name
-                },
-            )
-
-            assert isinstance(response, BlockOutputResponse)
-            assert response.success is True
-
-    @pytest.mark.asyncio(loop_scope="session")
-    async def test_missing_required_fields_returns_details(self):
-        """Missing required fields returns BlockDetailsResponse with schema."""
-        session = make_session(user_id=_TEST_USER_ID)
-
-        mock_block = make_mock_block_with_schema(
-            block_id="ai-text-gen-id",
-            name="AI Text Generator",
-            input_properties={
-                "prompt": {"type": "string"},
-                "model": {"type": "string", "default": "gpt-4o-mini"},
-            },
-            required_fields=["prompt"],
-        )
-
-        with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
-            return_value=mock_block,
-        ):
-            tool = RunBlockTool()
-
-            # Only provide valid optional field, missing required 'prompt'
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="ai-text-gen-id",
-                input_data={
-                    "model": "gpt-4o-mini",  # valid but optional
-                },
-            )
-
-            assert isinstance(response, BlockDetailsResponse)
@@ -1,265 +0,0 @@
-"""Sandbox execution utilities for code execution tools.
-
-Provides filesystem + network isolated command execution using **bubblewrap**
-(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
-writable workspace only, clean environment, network blocked.
-
-Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
-and refuse to run if bubblewrap is not available.
-"""
-
-import asyncio
-import logging
-import os
-import platform
-import shutil
-
-logger = logging.getLogger(__name__)
-
-_DEFAULT_TIMEOUT = 30
-_MAX_TIMEOUT = 120
-
-
-# ---------------------------------------------------------------------------
-# Sandbox capability detection (cached at first call)
-# ---------------------------------------------------------------------------
-
-_BWRAP_AVAILABLE: bool | None = None
-
-
-def has_full_sandbox() -> bool:
-    """Return True if bubblewrap is available (filesystem + network isolation).
-
-    On non-Linux platforms (macOS), always returns False.
-    """
-    global _BWRAP_AVAILABLE
-    if _BWRAP_AVAILABLE is None:
-        _BWRAP_AVAILABLE = (
-            platform.system() == "Linux" and shutil.which("bwrap") is not None
-        )
-    return _BWRAP_AVAILABLE
-
-
-WORKSPACE_PREFIX = "/tmp/copilot-"
-
-
-def make_session_path(session_id: str) -> str:
-    """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
-
-    Shared by both the SDK working-directory setup and the sandbox tools so
-    they always resolve to the same directory for a given session.
-
-    Steps:
-    1. Strip all characters except ``[A-Za-z0-9-]``.
-    2. Construct ``/tmp/copilot-<safe_id>``.
-    3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
-       sanitizer) to prevent path traversal.
-
-    Raises:
-        ValueError: If the resulting path escapes the prefix.
-    """
-    import re
-
-    safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
-    if not safe_id:
-        safe_id = "default"
-    path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
-    if not path.startswith(WORKSPACE_PREFIX):
-        raise ValueError(f"Session path escaped prefix: {path}")
-    return path
-
-
-def get_workspace_dir(session_id: str) -> str:
-    """Get or create the workspace directory for a session.
-
-    Uses :func:`make_session_path` — the same path the SDK uses — so that
-    bash_exec shares the workspace with the SDK file tools.
-    """
-    workspace = make_session_path(session_id)
-    os.makedirs(workspace, exist_ok=True)
-    return workspace
-
-
-# ---------------------------------------------------------------------------
-# Bubblewrap command builder
-# ---------------------------------------------------------------------------
-
-# System directories mounted read-only inside the sandbox.
-# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
-_SYSTEM_RO_BINDS = [
-    "/usr",  # binaries, libraries, Python interpreter
-    "/etc",  # system config: ld.so, locale, passwd, alternatives
-]
-
-# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
-# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
-# can't create a symlink target, so we detect and use --symlink instead.
-# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
-_COMPAT_PATHS = [
-    ("/bin", "usr/bin"),  # -> /usr/bin on Debian 13
-    ("/sbin", "usr/sbin"),  # -> /usr/sbin on Debian 13
-    ("/lib", "usr/lib"),  # -> /usr/lib on Debian 13
-    ("/lib64", "usr/lib64"),  # 64-bit libraries / ELF interpreter
-]
-
-# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
-# Applied via ulimit inside the sandbox before exec'ing the user command.
-_RESOURCE_LIMITS = (
-    "ulimit -u 64"  # max 64 processes (prevents fork bombs)
-    " -v 524288"  # 512 MB virtual memory
-    " -f 51200"  # 50 MB max file size (1024-byte blocks)
-    " -n 256"  # 256 open file descriptors
-    " 2>/dev/null"
-)
-
-
-def _build_bwrap_command(
-    command: list[str], cwd: str, env: dict[str, str]
-) -> list[str]:
-    """Build a bubblewrap command with strict filesystem + network isolation.
-
-    Security model:
-    - **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
-      ``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
-      home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
-    - **Writable workspace only**: the per-session workspace is the sole
-      writable path.
-    - **Clean environment**: ``--clearenv`` wipes all inherited env vars.
-      Only the explicitly-passed safe env vars are set inside the sandbox.
-    - **Network isolation**: ``--unshare-net`` blocks all network access.
-    - **Resource limits**: ulimit caps on processes (64), memory (512MB),
-      file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
-    - **New session**: prevents terminal control escape.
-    - **Die with parent**: prevents orphaned sandbox processes.
-    """
-    cmd = [
-        "bwrap",
-        # Create a new user namespace so bwrap can set up sandboxing
-        # inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
-        "--unshare-user",
-        # Wipe all inherited environment variables (API keys, secrets, etc.)
-        "--clearenv",
-    ]
-
-    # Set only the safe env vars inside the sandbox
-    for key, value in env.items():
-        cmd.extend(["--setenv", key, value])
-
-    # System directories: read-only
-    for path in _SYSTEM_RO_BINDS:
-        cmd.extend(["--ro-bind", path, path])
-
-    # Compat paths: use --symlink when host path is a symlink (Debian 13),
-    # --ro-bind when it's a real directory (older distros).
-    for path, symlink_target in _COMPAT_PATHS:
-        if os.path.islink(path):
-            cmd.extend(["--symlink", symlink_target, path])
-        elif os.path.exists(path):
-            cmd.extend(["--ro-bind", path, path])
-
-    # Wrap the user command with resource limits:
-    # sh -c 'ulimit ...; exec "$@"' -- <original command>
-    # `exec "$@"` replaces the shell so there's no extra process overhead,
-    # and properly handles arguments with spaces.
-    limited_command = [
-        "sh",
-        "-c",
-        f'{_RESOURCE_LIMITS}; exec "$@"',
-        "--",
-        *command,
-    ]
-
-    cmd.extend(
-        [
-            # Fresh virtual filesystems
-            "--dev",
-            "/dev",
-            "--proc",
-            "/proc",
-            "--tmpfs",
-            "/tmp",
-            # Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
-            # (workspace lives under /tmp/copilot-<session>)
-            "--bind",
-            cwd,
-            cwd,
-            # Isolation
-            "--unshare-net",
-            "--die-with-parent",
-            "--new-session",
-            "--chdir",
-            cwd,
-            "--",
-            *limited_command,
-        ]
-    )
-
-    return cmd
-
-
-# ---------------------------------------------------------------------------
-# Public API
-# ---------------------------------------------------------------------------
-
-
-async def run_sandboxed(
-    command: list[str],
-    cwd: str,
-    timeout: int = _DEFAULT_TIMEOUT,
-    env: dict[str, str] | None = None,
-) -> tuple[str, str, int, bool]:
-    """Run a command inside a bubblewrap sandbox.
-
-    Callers **must** check :func:`has_full_sandbox` before calling this
-    function. If bubblewrap is not available, this function raises
-    :class:`RuntimeError` rather than running unsandboxed.
-
-    Returns:
-        (stdout, stderr, exit_code, timed_out)
-    """
-    if not has_full_sandbox():
-        raise RuntimeError(
-            "run_sandboxed() requires bubblewrap but bwrap is not available. "
-            "Callers must check has_full_sandbox() before calling this function."
-        )
-
-    timeout = min(max(timeout, 1), _MAX_TIMEOUT)
-
-    safe_env = {
-        "PATH": "/usr/local/bin:/usr/bin:/bin",
-        "HOME": cwd,
-        "TMPDIR": cwd,
-        "LANG": "en_US.UTF-8",
-        "PYTHONDONTWRITEBYTECODE": "1",
-        "PYTHONIOENCODING": "utf-8",
-    }
-    if env:
-        safe_env.update(env)
-
-    full_command = _build_bwrap_command(command, cwd, safe_env)
-
-    try:
-        proc = await asyncio.create_subprocess_exec(
-            *full_command,
-            stdout=asyncio.subprocess.PIPE,
-            stderr=asyncio.subprocess.PIPE,
-            cwd=cwd,
-            env=safe_env,
-        )
-
-        try:
-            stdout_bytes, stderr_bytes = await asyncio.wait_for(
-                proc.communicate(), timeout=timeout
-            )
-            stdout = stdout_bytes.decode("utf-8", errors="replace")
-            stderr = stderr_bytes.decode("utf-8", errors="replace")
-            return stdout, stderr, proc.returncode or 0, False
-        except asyncio.TimeoutError:
-            proc.kill()
-            await proc.communicate()
-            return "", f"Execution timed out after {timeout}s", -1, True
-
-    except RuntimeError:
-        raise
-    except Exception as e:
-        return "", f"Sandbox error: {e}", -1, False
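Editor's note: the deleted module's docstring states that callers must check has_full_sandbox() before run_sandboxed(). A small usage sketch of that contract follows; the import path and session id are assumptions for illustration, since the file is removed in this diff.

import asyncio

# Hypothetical import path for the removed module; adjust to wherever the
# sandbox utilities actually lived in the tree.
from sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed


async def main() -> None:
    # The guard is mandatory: run_sandboxed() raises RuntimeError without bwrap.
    if not has_full_sandbox():
        print("bubblewrap not available; refusing to execute")
        return
    cwd = get_workspace_dir("example-session")
    stdout, stderr, exit_code, timed_out = await run_sandboxed(
        ["python3", "-c", "print('hello from the sandbox')"],
        cwd=cwd,
        timeout=10,
    )
    print(exit_code, timed_out, stdout)


asyncio.run(main())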
@@ -1,153 +0,0 @@
-"""Tests for BlockDetailsResponse in RunBlockTool."""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from backend.api.features.chat.tools.models import BlockDetailsResponse
-from backend.api.features.chat.tools.run_block import RunBlockTool
-from backend.blocks._base import BlockType
-from backend.data.model import CredentialsMetaInput
-from backend.integrations.providers import ProviderName
-
-from ._test_data import make_session
-
-_TEST_USER_ID = "test-user-run-block-details"
-
-
-def make_mock_block_with_inputs(
-    block_id: str, name: str, description: str = "Test description"
-):
-    """Create a mock block with input/output schemas for testing."""
-    mock = MagicMock()
-    mock.id = block_id
-    mock.name = name
-    mock.description = description
-    mock.block_type = BlockType.STANDARD
-    mock.disabled = False
-
-    # Input schema with non-credential fields
-    mock.input_schema = MagicMock()
-    mock.input_schema.jsonschema.return_value = {
-        "properties": {
-            "url": {"type": "string", "description": "URL to fetch"},
-            "method": {"type": "string", "description": "HTTP method"},
-        },
-        "required": ["url"],
-    }
-    mock.input_schema.get_credentials_fields.return_value = {}
-    mock.input_schema.get_credentials_fields_info.return_value = {}
-
-    # Output schema
-    mock.output_schema = MagicMock()
-    mock.output_schema.jsonschema.return_value = {
-        "properties": {
-            "response": {"type": "object", "description": "HTTP response"},
-            "error": {"type": "string", "description": "Error message"},
-        }
-    }
-
-    return mock
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_block_returns_details_when_no_input_provided():
-    """When run_block is called without input_data, it should return BlockDetailsResponse."""
-    session = make_session(user_id=_TEST_USER_ID)
-
-    # Create a block with inputs
-    http_block = make_mock_block_with_inputs(
-        "http-block-id", "HTTP Request", "Send HTTP requests"
-    )
-
-    with patch(
-        "backend.api.features.chat.tools.run_block.get_block",
-        return_value=http_block,
-    ):
-        # Mock credentials check to return no missing credentials
-        with patch.object(
-            RunBlockTool,
-            "_resolve_block_credentials",
-            new_callable=AsyncMock,
-            return_value=({}, []),  # (matched_credentials, missing_credentials)
-        ):
-            tool = RunBlockTool()
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="http-block-id",
-                input_data={},  # Empty input data
-            )
-
-    # Should return BlockDetailsResponse showing the schema
-    assert isinstance(response, BlockDetailsResponse)
-    assert response.block.id == "http-block-id"
-    assert response.block.name == "HTTP Request"
-    assert response.block.description == "Send HTTP requests"
-    assert "url" in response.block.inputs["properties"]
-    assert "method" in response.block.inputs["properties"]
-    assert "response" in response.block.outputs["properties"]
-    assert response.user_authenticated is True
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_block_returns_details_when_only_credentials_provided():
-    """When only credentials are provided (no actual input), should return details."""
-    session = make_session(user_id=_TEST_USER_ID)
-
-    # Create a block with both credential and non-credential inputs
-    mock = MagicMock()
-    mock.id = "api-block-id"
-    mock.name = "API Call"
-    mock.description = "Make API calls"
-    mock.block_type = BlockType.STANDARD
-    mock.disabled = False
-
-    mock.input_schema = MagicMock()
-    mock.input_schema.jsonschema.return_value = {
-        "properties": {
-            "credentials": {"type": "object", "description": "API credentials"},
-            "endpoint": {"type": "string", "description": "API endpoint"},
-        },
-        "required": ["credentials", "endpoint"],
-    }
-    mock.input_schema.get_credentials_fields.return_value = {"credentials": True}
-    mock.input_schema.get_credentials_fields_info.return_value = {}
-
-    mock.output_schema = MagicMock()
-    mock.output_schema.jsonschema.return_value = {
-        "properties": {"result": {"type": "object"}}
-    }
-
-    with patch(
-        "backend.api.features.chat.tools.run_block.get_block",
-        return_value=mock,
-    ):
-        with patch.object(
-            RunBlockTool,
-            "_resolve_block_credentials",
-            new_callable=AsyncMock,
-            return_value=(
-                {
-                    "credentials": CredentialsMetaInput(
-                        id="cred-id",
-                        provider=ProviderName("test_provider"),
-                        type="api_key",
-                        title="Test Credential",
-                    )
-                },
-                [],
-            ),
-        ):
-            tool = RunBlockTool()
-            response = await tool._execute(
-                user_id=_TEST_USER_ID,
-                session=session,
-                block_id="api-block-id",
-                input_data={"credentials": {"some": "cred"}},  # Only credential
-            )
-
-    # Should return details because no non-credential inputs provided
-    assert isinstance(response, BlockDetailsResponse)
-    assert response.block.id == "api-block-id"
-    assert response.block.name == "API Call"
@@ -118,7 +118,7 @@ def build_missing_credentials_from_graph(
     preserving all supported credential types for each field.
     """
     matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
-    aggregated_fields = graph.aggregate_credentials_inputs()
+    aggregated_fields = graph.regular_credentials_inputs

     return {
         field_key: _serialize_missing_credential(field_key, field_info)
@@ -338,7 +338,7 @@ async def match_user_credentials_to_graph(
     missing_creds: list[str] = []

     # Get aggregated credentials requirements from the graph
-    aggregated_creds = graph.aggregate_credentials_inputs()
+    aggregated_creds = graph.regular_credentials_inputs
     logger.debug(
         f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
     )
@@ -0,0 +1,78 @@
+"""Tests for chat tools utility functions."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from backend.data.model import CredentialsFieldInfo
+
+
+def _make_regular_field() -> CredentialsFieldInfo:
+    return CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["github"],
+            "credentials_types": ["api_key"],
+            "is_auto_credential": False,
+        },
+        by_alias=True,
+    )
+
+
+def test_build_missing_credentials_excludes_auto_creds():
+    """
+    build_missing_credentials_from_graph() should use regular_credentials_inputs
+    and thus exclude auto_credentials from the "missing" set.
+    """
+    from backend.api.features.chat.tools.utils import (
+        build_missing_credentials_from_graph,
+    )
+
+    regular_field = _make_regular_field()
+
+    mock_graph = MagicMock()
+    # regular_credentials_inputs should only return the non-auto field
+    mock_graph.regular_credentials_inputs = {
+        "github_api_key": (regular_field, {("node-1", "credentials")}, True),
+    }
+
+    result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
+
+    # Should include the regular credential
+    assert "github_api_key" in result
+    # Should NOT include the auto_credential (not in regular_credentials_inputs)
+    assert "google_oauth2" not in result
+
+
+@pytest.mark.asyncio
+async def test_match_user_credentials_excludes_auto_creds():
+    """
+    match_user_credentials_to_graph() should use regular_credentials_inputs
+    and thus exclude auto_credentials from matching.
+    """
+    from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
+
+    regular_field = _make_regular_field()
+
+    mock_graph = MagicMock()
+    mock_graph.id = "test-graph"
+    # regular_credentials_inputs returns only non-auto fields
+    mock_graph.regular_credentials_inputs = {
+        "github_api_key": (regular_field, {("node-1", "credentials")}, True),
+    }
+
+    # Mock the credentials manager to return no credentials
+    with patch(
+        "backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
+    ) as MockCredsMgr:
+        mock_store = AsyncMock()
+        mock_store.get_all_creds.return_value = []
+        MockCredsMgr.return_value.store = mock_store
+
+        matched, missing = await match_user_credentials_to_graph(
+            user_id="test-user", graph=mock_graph
+        )
+
+    # No credentials available, so github should be missing
+    assert len(matched) == 0
+    assert len(missing) == 1
+    assert "github_api_key" in missing[0]
@@ -1,151 +0,0 @@
-"""Web fetch tool — safely retrieve public web page content."""
-
-import logging
-from typing import Any
-
-import aiohttp
-import html2text
-
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
-    ErrorResponse,
-    ToolResponseBase,
-    WebFetchResponse,
-)
-from backend.util.request import Requests
-
-logger = logging.getLogger(__name__)
-
-# Limits
-_MAX_CONTENT_BYTES = 102_400  # 100 KB download cap
-_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
-
-# Content types we'll read as text
-_TEXT_CONTENT_TYPES = {
-    "text/html",
-    "text/plain",
-    "text/xml",
-    "text/csv",
-    "text/markdown",
-    "application/json",
-    "application/xml",
-    "application/xhtml+xml",
-    "application/rss+xml",
-    "application/atom+xml",
-}
-
-
-def _is_text_content(content_type: str) -> bool:
-    base = content_type.split(";")[0].strip().lower()
-    return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
-
-
-def _html_to_text(html: str) -> str:
-    h = html2text.HTML2Text()
-    h.ignore_links = False
-    h.ignore_images = True
-    h.body_width = 0
-    return h.handle(html)
-
-
-class WebFetchTool(BaseTool):
-    """Safely fetch content from a public URL using SSRF-protected HTTP."""
-
-    @property
-    def name(self) -> str:
-        return "web_fetch"
-
-    @property
-    def description(self) -> str:
-        return (
-            "Fetch the content of a public web page by URL. "
-            "Returns readable text extracted from HTML by default. "
-            "Useful for reading documentation, articles, and API responses. "
-            "Only supports HTTP/HTTPS GET requests to public URLs "
-            "(private/internal network addresses are blocked)."
-        )
-
-    @property
-    def parameters(self) -> dict[str, Any]:
-        return {
-            "type": "object",
-            "properties": {
-                "url": {
-                    "type": "string",
-                    "description": "The public HTTP/HTTPS URL to fetch.",
-                },
-                "extract_text": {
-                    "type": "boolean",
-                    "description": (
-                        "If true (default), extract readable text from HTML. "
-                        "If false, return raw content."
-                    ),
-                    "default": True,
-                },
-            },
-            "required": ["url"],
-        }
-
-    @property
-    def requires_auth(self) -> bool:
-        return False
-
-    async def _execute(
-        self,
-        user_id: str | None,
-        session: ChatSession,
-        **kwargs: Any,
-    ) -> ToolResponseBase:
-        url: str = (kwargs.get("url") or "").strip()
-        extract_text: bool = kwargs.get("extract_text", True)
-        session_id = session.session_id if session else None
-
-        if not url:
-            return ErrorResponse(
-                message="Please provide a URL to fetch.",
-                error="missing_url",
-                session_id=session_id,
-            )
-
-        try:
-            client = Requests(raise_for_status=False, retry_max_attempts=1)
-            response = await client.get(url, timeout=_REQUEST_TIMEOUT)
-        except ValueError as e:
-            # validate_url raises ValueError for SSRF / blocked IPs
-            return ErrorResponse(
-                message=f"URL blocked: {e}",
-                error="url_blocked",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.warning(f"[web_fetch] Request failed for {url}: {e}")
-            return ErrorResponse(
-                message=f"Failed to fetch URL: {e}",
-                error="fetch_failed",
-                session_id=session_id,
-            )
-
-        content_type = response.headers.get("content-type", "")
-        if not _is_text_content(content_type):
-            return ErrorResponse(
-                message=f"Non-text content type: {content_type.split(';')[0]}",
-                error="unsupported_content_type",
-                session_id=session_id,
-            )
-
-        raw = response.content[:_MAX_CONTENT_BYTES]
-        text = raw.decode("utf-8", errors="replace")
-
-        if extract_text and "html" in content_type.lower():
-            text = _html_to_text(text)
-
-        return WebFetchResponse(
-            message=f"Fetched {url}",
-            url=response.url,
-            status_code=response.status,
-            content_type=content_type.split(";")[0].strip(),
-            content=text,
-            truncated=False,
-            session_id=session_id,
-        )
@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "List files in the user's persistent workspace (cloud storage). "
-            "These files survive across sessions. "
-            "For ephemeral session files, use the SDK Read/Glob tools instead. "
+            "List files in the user's workspace. "
             "Returns file names, paths, sizes, and metadata. "
             "Optionally filter by path prefix."
         )
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Read a file from the user's persistent workspace (cloud storage). "
-            "These files survive across sessions. "
-            "For ephemeral session files, use the SDK Read tool instead. "
+            "Read a file from the user's workspace. "
             "Specify either file_id or path to identify the file. "
             "For small text files, returns content directly. "
             "For large or binary files, returns metadata and a download URL. "
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Write or create a file in the user's persistent workspace (cloud storage). "
-            "These files survive across sessions. "
-            "For ephemeral session files, use the SDK Write tool instead. "
+            "Write or create a file in the user's workspace. "
             "Provide the content as a base64-encoded string. "
             f"Maximum file size is {Config().max_file_size_mb}MB. "
             "Files are saved to the current session's folder by default. "
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
     @property
     def description(self) -> str:
         return (
-            "Delete a file from the user's persistent workspace (cloud storage). "
+            "Delete a file from the user's workspace. "
            "Specify either file_id or path to identify the file. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
@@ -1102,7 +1102,7 @@ async def create_preset_from_graph_execution(
         raise NotFoundError(
             f"Graph #{graph_execution.graph_id} not found or accessible"
         )
-    elif len(graph.aggregate_credentials_inputs()) > 0:
+    elif len(graph.regular_credentials_inputs) > 0:
         raise ValueError(
             f"Graph execution #{graph_exec_id} can't be turned into a preset "
             "because it was run before this feature existed "
@@ -309,6 +309,8 @@ class BlockSchema(BaseModel):
                 "credentials_provider": [config.get("provider", "google")],
                 "credentials_types": [config.get("type", "oauth2")],
                 "credentials_scopes": config.get("scopes"),
+                "is_auto_credential": True,
+                "input_field_name": info["field_name"],
             }
             result[kwarg_name] = CredentialsFieldInfo.model_validate(
                 auto_schema, by_alias=True
@@ -434,8 +434,7 @@ class GraphModel(Graph, GraphMeta):
     @computed_field
     @property
     def credentials_input_schema(self) -> dict[str, Any]:
-        graph_credentials_inputs = self.aggregate_credentials_inputs()
-
+        graph_credentials_inputs = self.regular_credentials_inputs
         logger.debug(
             f"Combined credentials input fields for graph #{self.id} ({self.name}): "
             f"{graph_credentials_inputs}"
@@ -591,6 +590,28 @@ class GraphModel(Graph, GraphMeta):
             for key, (field_info, node_field_pairs) in combined.items()
         }

+    @property
+    def regular_credentials_inputs(
+        self,
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
+        """Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
+        return {
+            k: v
+            for k, v in self.aggregate_credentials_inputs().items()
+            if not v[0].is_auto_credential
+        }
+
+    @property
+    def auto_credentials_inputs(
+        self,
+    ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
+        """Credentials embedded in file fields (_credentials_id), resolved at execution time."""
+        return {
+            k: v
+            for k, v in self.aggregate_credentials_inputs().items()
+            if v[0].is_auto_credential
+        }
+
     def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
         """
         Reassigns all IDs in the graph to new UUIDs.
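Editor's note: a self-contained sketch of the partition the two new properties perform over aggregate_credentials_inputs(), using stand-in objects; only the is_auto_credential flag and the (field_info, node-field pairs, required) tuple shape are taken from the diff above, the data is illustrative.

from dataclasses import dataclass


@dataclass
class FieldInfoStub:
    # Stand-in for CredentialsFieldInfo; only the flag used by the filter.
    is_auto_credential: bool


aggregate = {
    "github_api_key": (FieldInfoStub(False), {("node-1", "credentials")}, True),
    "google_oauth2": (FieldInfoStub(True), {("node-2", "spreadsheet")}, True),
}

# Same comprehensions as regular_credentials_inputs / auto_credentials_inputs.
regular = {k: v for k, v in aggregate.items() if not v[0].is_auto_credential}
auto = {k: v for k, v in aggregate.items() if v[0].is_auto_credential}

assert set(regular) == {"github_api_key"}
assert set(auto) == {"google_oauth2"}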
@@ -641,6 +662,16 @@ class GraphModel(Graph, GraphMeta):
                 ) and graph_id in graph_id_map:
                     node.input_default["graph_id"] = graph_id_map[graph_id]

+        # Clear auto-credentials references (e.g., _credentials_id in
+        # GoogleDriveFile fields) so the new user must re-authenticate
+        # with their own account
+        for node in graph.nodes:
+            if not node.input_default:
+                continue
+            for key, value in node.input_default.items():
+                if isinstance(value, dict) and "_credentials_id" in value:
+                    del value["_credentials_id"]
+
     def validate_graph(
         self,
         for_run: bool = False,
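Editor's note: a before/after sketch of the credential-clearing loop added above, run against a made-up GoogleDriveFile-style input_default; it mirrors the tests added later in this diff.

# Illustrative data only. The key is deleted outright (not set to None), which
# is what the inline comment above and the _reassign_ids tests below rely on.
input_default = {
    "spreadsheet": {
        "_credentials_id": "original-cred-id",
        "id": "file-123",
        "name": "test.xlsx",
    }
}

for value in input_default.values():
    if isinstance(value, dict) and "_credentials_id" in value:
        del value["_credentials_id"]

assert "_credentials_id" not in input_default["spreadsheet"]
assert input_default["spreadsheet"]["id"] == "file-123"  # other fields intact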
@@ -462,3 +462,329 @@ def test_node_credentials_optional_with_other_metadata():
     assert node.credentials_optional is True
     assert node.metadata["position"] == {"x": 100, "y": 200}
     assert node.metadata["customized_name"] == "My Custom Node"
+
+
+# ============================================================================
+# Tests for CredentialsFieldInfo.combine() field propagation
+def test_combine_preserves_is_auto_credential_flag():
+    """
+    CredentialsFieldInfo.combine() must propagate is_auto_credential and
+    input_field_name to the combined result. Regression test for reviewer
+    finding that combine() dropped these fields.
+    """
+    from backend.data.model import CredentialsFieldInfo
+
+    auto_field = CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["google"],
+            "credentials_types": ["oauth2"],
+            "credentials_scopes": ["drive.readonly"],
+            "is_auto_credential": True,
+            "input_field_name": "spreadsheet",
+        },
+        by_alias=True,
+    )
+
+    # combine() takes *args of (field_info, key) tuples
+    combined = CredentialsFieldInfo.combine(
+        (auto_field, ("node-1", "credentials")),
+        (auto_field, ("node-2", "credentials")),
+    )
+
+    assert len(combined) == 1
+    group_key = next(iter(combined))
+    combined_info, combined_keys = combined[group_key]
+
+    assert combined_info.is_auto_credential is True
+    assert combined_info.input_field_name == "spreadsheet"
+    assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
+
+
+def test_combine_preserves_regular_credential_defaults():
+    """Regular credentials should have is_auto_credential=False after combine()."""
+    from backend.data.model import CredentialsFieldInfo
+
+    regular_field = CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["github"],
+            "credentials_types": ["api_key"],
+            "is_auto_credential": False,
+        },
+        by_alias=True,
+    )
+
+    combined = CredentialsFieldInfo.combine(
+        (regular_field, ("node-1", "credentials")),
+    )
+
+    group_key = next(iter(combined))
+    combined_info, _ = combined[group_key]
+
+    assert combined_info.is_auto_credential is False
+    assert combined_info.input_field_name is None
+
+
+# ============================================================================
+# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
+
+
+def test_reassign_ids_clears_credentials_id():
+    """
+    [SECRT-1772] _reassign_ids should clear _credentials_id from
+    GoogleDriveFile-style input_default fields so forked agents
+    don't retain the original creator's credential references.
+    """
+    from backend.data.graph import GraphModel
+
+    node = Node(
+        id="node-1",
+        block_id=StoreValueBlock().id,
+        input_default={
+            "spreadsheet": {
+                "_credentials_id": "original-cred-id",
+                "id": "file-123",
+                "name": "test.xlsx",
+                "mimeType": "application/vnd.google-apps.spreadsheet",
+                "url": "https://docs.google.com/spreadsheets/d/file-123",
+            },
+        },
+    )
+
+    graph = Graph(
+        id="test-graph",
+        name="Test",
+        description="Test",
+        nodes=[node],
+        links=[],
+    )
+
+    GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
+
+    # _credentials_id key should be removed (not set to None) so that
+    # _acquire_auto_credentials correctly errors instead of treating it as chained data
+    assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
+
+
+def test_reassign_ids_preserves_non_credential_fields():
+    """
+    Regression guard: _reassign_ids should NOT modify non-credential fields
+    like name, mimeType, id, url.
+    """
+    from backend.data.graph import GraphModel
+
+    node = Node(
+        id="node-1",
+        block_id=StoreValueBlock().id,
+        input_default={
+            "spreadsheet": {
+                "_credentials_id": "cred-abc",
+                "id": "file-123",
+                "name": "test.xlsx",
+                "mimeType": "application/vnd.google-apps.spreadsheet",
+                "url": "https://docs.google.com/spreadsheets/d/file-123",
+            },
+        },
+    )
+
+    graph = Graph(
+        id="test-graph",
+        name="Test",
+        description="Test",
+        nodes=[node],
+        links=[],
+    )
+
+    GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
+
+    field = graph.nodes[0].input_default["spreadsheet"]
+    assert field["id"] == "file-123"
+    assert field["name"] == "test.xlsx"
+    assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
+    assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
+
+
+def test_reassign_ids_handles_no_credentials():
+    """
+    Regression guard: _reassign_ids should not error when input_default
+    has no dict fields with _credentials_id.
+    """
+    from backend.data.graph import GraphModel
+
+    node = Node(
+        id="node-1",
+        block_id=StoreValueBlock().id,
+        input_default={
+            "input": "some value",
+            "another_input": 42,
+        },
+    )
+
+    graph = Graph(
+        id="test-graph",
+        name="Test",
+        description="Test",
+        nodes=[node],
+        links=[],
+    )
+
+    GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
+
+    # Should not error, fields unchanged
+    assert graph.nodes[0].input_default["input"] == "some value"
+    assert graph.nodes[0].input_default["another_input"] == 42
+
+
+def test_reassign_ids_handles_multiple_credential_fields():
+    """
+    [SECRT-1772] When a node has multiple dict fields with _credentials_id,
+    ALL of them should be cleared.
+    """
+    from backend.data.graph import GraphModel
+
+    node = Node(
+        id="node-1",
+        block_id=StoreValueBlock().id,
+        input_default={
+            "spreadsheet": {
+                "_credentials_id": "cred-1",
+                "id": "file-1",
+                "name": "file1.xlsx",
+            },
+            "doc_file": {
+                "_credentials_id": "cred-2",
+                "id": "file-2",
+                "name": "file2.docx",
+            },
+            "plain_input": "not a dict",
+        },
+    )
+
+    graph = Graph(
+        id="test-graph",
+        name="Test",
+        description="Test",
+        nodes=[node],
+        links=[],
+    )
+
+    GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
+
+    assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
+    assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
+    assert graph.nodes[0].input_default["plain_input"] == "not a dict"
+
+
+# ============================================================================
+# Tests for discriminate() field propagation
+def test_discriminate_preserves_is_auto_credential_flag():
+    """
+    CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
+    input_field_name to the discriminated result. Regression test for
+    discriminate() dropping these fields (same class of bug as combine()).
+    """
+    from backend.data.model import CredentialsFieldInfo
+
+    auto_field = CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["google", "openai"],
+            "credentials_types": ["oauth2"],
+            "credentials_scopes": ["drive.readonly"],
+            "is_auto_credential": True,
+            "input_field_name": "spreadsheet",
+            "discriminator": "model",
+            "discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
+        },
+        by_alias=True,
+    )
+
+    discriminated = auto_field.discriminate("gemini")
+
+    assert discriminated.is_auto_credential is True
+    assert discriminated.input_field_name == "spreadsheet"
+    assert discriminated.provider == frozenset(["google"])
+
+
+def test_discriminate_preserves_regular_credential_defaults():
+    """Regular credentials should have is_auto_credential=False after discriminate()."""
+    from backend.data.model import CredentialsFieldInfo
+
+    regular_field = CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["google", "openai"],
+            "credentials_types": ["api_key"],
+            "is_auto_credential": False,
+            "discriminator": "model",
+            "discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
+        },
+        by_alias=True,
+    )
+
+    discriminated = regular_field.discriminate("gpt-4")
+
+    assert discriminated.is_auto_credential is False
+    assert discriminated.input_field_name is None
+    assert discriminated.provider == frozenset(["openai"])
+
+
+# ============================================================================
+# Tests for credentials_input_schema excluding auto_credentials
+def test_credentials_input_schema_excludes_auto_creds():
+    """
+    GraphModel.credentials_input_schema should exclude auto_credentials
+    (is_auto_credential=True) from the schema. Auto_credentials are
+    transparently resolved at execution time via file picker data.
+    """
+    from datetime import datetime, timezone
+    from unittest.mock import PropertyMock, patch
+
+    from backend.data.graph import GraphModel, NodeModel
+    from backend.data.model import CredentialsFieldInfo
+
+    regular_field_info = CredentialsFieldInfo.model_validate(
+        {
+            "credentials_provider": ["github"],
+            "credentials_types": ["api_key"],
+            "is_auto_credential": False,
+        },
+        by_alias=True,
+    )
+
+    graph = GraphModel(
+        id="test-graph",
+        version=1,
+        name="Test",
+        description="Test",
+        user_id="test-user",
+        created_at=datetime.now(timezone.utc),
+        nodes=[
+            NodeModel(
+                id="node-1",
+                block_id=StoreValueBlock().id,
+                input_default={},
+                graph_id="test-graph",
+                graph_version=1,
+            ),
+        ],
+        links=[],
+    )
+
+    # Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
+    regular_only = {
+        "github_credentials": (
+            regular_field_info,
+            {("node-1", "credentials")},
+            True,
+        ),
+    }
+
+    with patch.object(
+        type(graph),
+        "regular_credentials_inputs",
+        new_callable=PropertyMock,
+        return_value=regular_only,
+    ):
+        schema = graph.credentials_input_schema
+        field_names = set(schema.get("properties", {}).keys())
+        # Should include regular credential but NOT auto_credential
+        assert "github_credentials" in field_names
+        assert "google_credentials" not in field_names
@@ -574,6 +574,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
     discriminator: Optional[str] = None
     discriminator_mapping: Optional[dict[str, CP]] = None
     discriminator_values: set[Any] = Field(default_factory=set)
+    is_auto_credential: bool = False
+    input_field_name: Optional[str] = None
 
     @classmethod
     def combine(
@@ -654,6 +656,9 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
                 + "_credentials"
             )
 
+            # Propagate is_auto_credential from the combined field.
+            # All fields in a group should share the same is_auto_credential
+            # value since auto and regular credentials serve different purposes.
             result[group_key] = (
                 CredentialsFieldInfo[CP, CT](
                     credentials_provider=combined.provider,
@@ -662,6 +667,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
                     discriminator=combined.discriminator,
                     discriminator_mapping=combined.discriminator_mapping,
                     discriminator_values=set(all_discriminator_values),
+                    is_auto_credential=combined.is_auto_credential,
+                    input_field_name=combined.input_field_name,
                 ),
                 combined_keys,
             )
@@ -687,6 +694,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
             discriminator=self.discriminator,
             discriminator_mapping=self.discriminator_mapping,
             discriminator_values=self.discriminator_values,
+            is_auto_credential=self.is_auto_credential,
+            input_field_name=self.input_field_name,
         )
 
@@ -168,6 +168,81 @@ def execute_graph(
 T = TypeVar("T")
 
 
+async def _acquire_auto_credentials(
+    input_model: type[BlockSchema],
+    input_data: dict[str, Any],
+    creds_manager: "IntegrationCredentialsManager",
+    user_id: str,
+) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
+    """
+    Resolve auto_credentials from GoogleDriveFileField-style inputs.
+
+    Returns:
+        (extra_exec_kwargs, locks): kwargs to inject into block execution, and
+        credential locks to release after execution completes.
+    """
+    extra_exec_kwargs: dict[str, Any] = {}
+    locks: list[AsyncRedisLock] = []
+
+    # NOTE: If a block ever has multiple auto-credential fields, a ValueError
+    # on a later field will strand locks acquired for earlier fields. They'll
+    # auto-expire via Redis TTL, but add a try/except to release partial locks
+    # if that becomes a real scenario.
+    for kwarg_name, info in input_model.get_auto_credentials_fields().items():
+        field_name = info["field_name"]
+        field_data = input_data.get(field_name)
+
+        if field_data and isinstance(field_data, dict):
+            # Check if _credentials_id key exists in the field data
+            if "_credentials_id" in field_data:
+                cred_id = field_data["_credentials_id"]
+                if cred_id:
+                    # Credential ID provided - acquire credentials
+                    provider = info.get("config", {}).get(
+                        "provider", "external service"
+                    )
+                    file_name = field_data.get("name", "selected file")
+                    try:
+                        credentials, lock = await creds_manager.acquire(
+                            user_id, cred_id
+                        )
+                        locks.append(lock)
+                        extra_exec_kwargs[kwarg_name] = credentials
+                    except ValueError:
+                        raise ValueError(
+                            f"{provider.capitalize()} credentials for "
+                            f"'{file_name}' in field '{field_name}' are not "
+                            f"available in your account. "
+                            f"This can happen if the agent was created by another "
+                            f"user or the credentials were deleted. "
+                            f"Please open the agent in the builder and re-select "
+                            f"the file to authenticate with your own account."
+                        )
+                # else: _credentials_id is explicitly None, skip (chained data)
+            else:
+                # _credentials_id key missing entirely - this is an error
+                provider = info.get("config", {}).get("provider", "external service")
+                file_name = field_data.get("name", "selected file")
+                raise ValueError(
+                    f"Authentication missing for '{file_name}' in field "
+                    f"'{field_name}'. Please re-select the file to authenticate "
+                    f"with {provider.capitalize()}."
+                )
+        elif field_data is None and field_name not in input_data:
+            # Field not in input_data at all = connected from upstream block, skip
+            pass
+        else:
+            # field_data is None/empty but key IS in input_data = user didn't select
+            provider = info.get("config", {}).get("provider", "external service")
+            raise ValueError(
+                f"No file selected for '{field_name}'. "
+                f"Please select a file to provide "
+                f"{provider.capitalize()} authentication."
+            )
+
+    return extra_exec_kwargs, locks
+
+
 async def execute_node(
     node: Node,
     data: NodeExecutionEntry,
@@ -270,41 +345,14 @@ async def execute_node(
         extra_exec_kwargs[field_name] = credentials
 
     # Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
-    for kwarg_name, info in input_model.get_auto_credentials_fields().items():
-        field_name = info["field_name"]
-        field_data = input_data.get(field_name)
-        if field_data and isinstance(field_data, dict):
-            # Check if _credentials_id key exists in the field data
-            if "_credentials_id" in field_data:
-                cred_id = field_data["_credentials_id"]
-                if cred_id:
-                    # Credential ID provided - acquire credentials
-                    provider = info.get("config", {}).get(
-                        "provider", "external service"
-                    )
-                    file_name = field_data.get("name", "selected file")
-                    try:
-                        credentials, lock = await creds_manager.acquire(
-                            user_id, cred_id
-                        )
-                        creds_locks.append(lock)
-                        extra_exec_kwargs[kwarg_name] = credentials
-                    except ValueError:
-                        # Credential was deleted or doesn't exist
-                        raise ValueError(
-                            f"Authentication expired for '{file_name}' in field '{field_name}'. "
-                            f"The saved {provider.capitalize()} credentials no longer exist. "
-                            f"Please re-select the file to re-authenticate."
-                        )
-                # else: _credentials_id is explicitly None, skip credentials (for chained data)
-            else:
-                # _credentials_id key missing entirely - this is an error
-                provider = info.get("config", {}).get("provider", "external service")
-                file_name = field_data.get("name", "selected file")
-                raise ValueError(
-                    f"Authentication missing for '{file_name}' in field '{field_name}'. "
-                    f"Please re-select the file to authenticate with {provider.capitalize()}."
-                )
+    auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
+        input_model=input_model,
+        input_data=input_data,
+        creds_manager=creds_manager,
+        user_id=user_id,
+    )
+    extra_exec_kwargs.update(auto_extra_kwargs)
+    creds_locks.extend(auto_locks)
 
     output_size = 0
 
@@ -0,0 +1,320 @@
+"""
+Tests for auto_credentials handling in execute_node().
+
+These test the _acquire_auto_credentials() helper function extracted from
+execute_node() (manager.py lines 273-308).
+"""
+
+import pytest
+from pytest_mock import MockerFixture
+
+
+@pytest.fixture
+def google_drive_file_data():
+    return {
+        "valid": {
+            "_credentials_id": "cred-id-123",
+            "id": "file-123",
+            "name": "test.xlsx",
+            "mimeType": "application/vnd.google-apps.spreadsheet",
+        },
+        "chained": {
+            "_credentials_id": None,
+            "id": "file-456",
+            "name": "chained.xlsx",
+            "mimeType": "application/vnd.google-apps.spreadsheet",
+        },
+        "missing_key": {
+            "id": "file-789",
+            "name": "bad.xlsx",
+            "mimeType": "application/vnd.google-apps.spreadsheet",
+        },
+    }
+
+
+@pytest.fixture
+def mock_input_model(mocker: MockerFixture):
+    """Create a mock input model with get_auto_credentials_fields() returning one field."""
+    input_model = mocker.MagicMock()
+    input_model.get_auto_credentials_fields.return_value = {
+        "credentials": {
+            "field_name": "spreadsheet",
+            "config": {
+                "provider": "google",
+                "type": "oauth2",
+                "scopes": ["https://www.googleapis.com/auth/drive.readonly"],
+            },
+        }
+    }
+    return input_model
+
+
+@pytest.fixture
+def mock_creds_manager(mocker: MockerFixture):
+    manager = mocker.AsyncMock()
+    mock_lock = mocker.AsyncMock()
+    mock_creds = mocker.MagicMock()
+    mock_creds.id = "cred-id-123"
+    mock_creds.provider = "google"
+    manager.acquire.return_value = (mock_creds, mock_lock)
+    return manager, mock_creds, mock_lock
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_happy_path(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """When field_data has a valid _credentials_id, credentials should be acquired."""
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, mock_creds, mock_lock = mock_creds_manager
+    input_data = {"spreadsheet": google_drive_file_data["valid"]}
+
+    extra_kwargs, locks = await _acquire_auto_credentials(
+        input_model=mock_input_model,
+        input_data=input_data,
+        creds_manager=manager,
+        user_id="user-1",
+    )
+
+    manager.acquire.assert_called_once_with("user-1", "cred-id-123")
+    assert extra_kwargs["credentials"] == mock_creds
+    assert mock_lock in locks
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_field_none_static_raises(
+    mocker: MockerFixture,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    [THE BUG FIX TEST — OPEN-2895]
+    When field_data is None and the key IS in input_data (user didn't select a file),
+    should raise ValueError instead of silently skipping.
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    # Key is present but value is None = user didn't select a file
+    input_data = {"spreadsheet": None}
+
+    with pytest.raises(ValueError, match="No file selected"):
+        await _acquire_auto_credentials(
+            input_model=mock_input_model,
+            input_data=input_data,
+            creds_manager=manager,
+            user_id="user-1",
+        )
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_field_absent_skips(
+    mocker: MockerFixture,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    When the field key is NOT in input_data at all (upstream connection),
+    should skip without error.
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    # Key not present = connected from upstream block
+    input_data = {}
+
+    extra_kwargs, locks = await _acquire_auto_credentials(
+        input_model=mock_input_model,
+        input_data=input_data,
+        creds_manager=manager,
+        user_id="user-1",
+    )
+
+    manager.acquire.assert_not_called()
+    assert "credentials" not in extra_kwargs
+    assert locks == []
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_chained_cred_id_none(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    When _credentials_id is explicitly None (chained data from upstream),
+    should skip credential acquisition.
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    input_data = {"spreadsheet": google_drive_file_data["chained"]}
+
+    extra_kwargs, locks = await _acquire_auto_credentials(
+        input_model=mock_input_model,
+        input_data=input_data,
+        creds_manager=manager,
+        user_id="user-1",
+    )
+
+    manager.acquire.assert_not_called()
+    assert "credentials" not in extra_kwargs
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_missing_cred_id_key_raises(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    When _credentials_id key is missing entirely from field_data dict,
+    should raise ValueError.
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
+
+    with pytest.raises(ValueError, match="Authentication missing"):
+        await _acquire_auto_credentials(
+            input_model=mock_input_model,
+            input_data=input_data,
+            creds_manager=manager,
+            user_id="user-1",
+        )
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_ownership_mismatch_error(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    [SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
+    the error message should mention 'not available' (not 'expired').
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    manager.acquire.side_effect = ValueError(
+        "Credentials #cred-id-123 for user #user-2 not found"
+    )
+    input_data = {"spreadsheet": google_drive_file_data["valid"]}
+
+    with pytest.raises(ValueError, match="not available in your account"):
+        await _acquire_auto_credentials(
+            input_model=mock_input_model,
+            input_data=input_data,
+            creds_manager=manager,
+            user_id="user-2",
+        )
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_deleted_credential_error(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """
+    [SECRT-1772] When acquire() raises ValueError (credential was deleted),
+    the error message should mention 'not available' (not 'expired').
+    """
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, _ = mock_creds_manager
+    manager.acquire.side_effect = ValueError(
+        "Credentials #cred-id-123 for user #user-1 not found"
+    )
+    input_data = {"spreadsheet": google_drive_file_data["valid"]}
+
+    with pytest.raises(ValueError, match="not available in your account"):
+        await _acquire_auto_credentials(
+            input_model=mock_input_model,
+            input_data=input_data,
+            creds_manager=manager,
+            user_id="user-1",
+        )
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_lock_appended(
+    mocker: MockerFixture,
+    google_drive_file_data,
+    mock_input_model,
+    mock_creds_manager,
+):
+    """Lock from acquire() should be included in returned locks list."""
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, _, mock_lock = mock_creds_manager
+    input_data = {"spreadsheet": google_drive_file_data["valid"]}
+
+    extra_kwargs, locks = await _acquire_auto_credentials(
+        input_model=mock_input_model,
+        input_data=input_data,
+        creds_manager=manager,
+        user_id="user-1",
+    )
+
+    assert len(locks) == 1
+    assert locks[0] is mock_lock
+
+
+@pytest.mark.asyncio
+async def test_auto_credentials_multiple_fields(
+    mocker: MockerFixture,
+    mock_creds_manager,
+):
+    """When there are multiple auto_credentials fields, only valid ones should acquire."""
+    from backend.executor.manager import _acquire_auto_credentials
+
+    manager, mock_creds, mock_lock = mock_creds_manager
+
+    input_model = mocker.MagicMock()
+    input_model.get_auto_credentials_fields.return_value = {
+        "credentials": {
+            "field_name": "spreadsheet",
+            "config": {"provider": "google", "type": "oauth2"},
+        },
+        "credentials2": {
+            "field_name": "doc_file",
+            "config": {"provider": "google", "type": "oauth2"},
+        },
+    }
+
+    input_data = {
+        "spreadsheet": {
+            "_credentials_id": "cred-id-123",
+            "id": "file-1",
+            "name": "file1.xlsx",
+        },
+        "doc_file": {
+            "_credentials_id": None,
+            "id": "file-2",
+            "name": "chained.doc",
+        },
+    }
+
+    extra_kwargs, locks = await _acquire_auto_credentials(
+        input_model=input_model,
+        input_data=input_data,
+        creds_manager=manager,
+        user_id="user-1",
+    )
+
+    # Only the first field should have acquired credentials
+    manager.acquire.assert_called_once_with("user-1", "cred-id-123")
+    assert "credentials" in extra_kwargs
+    assert "credentials2" not in extra_kwargs
+    assert len(locks) == 1
@@ -254,7 +254,8 @@ async def _validate_node_input_credentials(
 
         # Find any fields of type CredentialsMetaInput
         credentials_fields = block.input_schema.get_credentials_fields()
-        if not credentials_fields:
+        auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
+        if not credentials_fields and not auto_credentials_fields:
             continue
 
         # Track if any credential field is missing for this node
@@ -334,6 +335,47 @@ async def _validate_node_input_credentials(
                 ] = "Invalid credentials: type/provider mismatch"
                 continue
 
+        # Validate auto-credentials (GoogleDriveFileField-based)
+        # These have _credentials_id embedded in the file field data
+        if auto_credentials_fields:
+            for _kwarg_name, info in auto_credentials_fields.items():
+                field_name = info["field_name"]
+                # Check input_default and nodes_input_masks for the field value
+                field_value = node.input_default.get(field_name)
+                if nodes_input_masks and node.id in nodes_input_masks:
+                    field_value = nodes_input_masks[node.id].get(
+                        field_name, field_value
+                    )
+
+                if field_value and isinstance(field_value, dict):
+                    if "_credentials_id" not in field_value:
+                        # Key removed (e.g., on fork) — needs re-auth
+                        has_missing_credentials = True
+                        credential_errors[node.id][field_name] = (
+                            "Authentication missing for the selected file. "
+                            "Please re-select the file to authenticate with "
+                            "your own account."
+                        )
+                        continue
+                    cred_id = field_value.get("_credentials_id")
+                    if cred_id and isinstance(cred_id, str):
+                        try:
+                            creds_store = get_integration_credentials_store()
+                            creds = await creds_store.get_creds_by_id(user_id, cred_id)
+                        except Exception as e:
+                            has_missing_credentials = True
+                            credential_errors[node.id][
+                                field_name
+                            ] = f"Credentials not available: {e}"
+                            continue
+                        if not creds:
+                            has_missing_credentials = True
+                            credential_errors[node.id][field_name] = (
+                                "The saved credentials are not available "
+                                "for your account. Please re-select the file to "
+                                "authenticate with your own account."
+                            )
+
         # If node has optional credentials and any are missing, mark for skipping
         # But only if there are no other errors for this node
         if (
@@ -365,8 +407,9 @@ def make_node_credentials_input_map(
     """
     result: dict[str, dict[str, JsonValue]] = {}
 
-    # Get aggregated credentials fields for the graph
-    graph_cred_inputs = graph.aggregate_credentials_inputs()
+    # Only map regular credentials (not auto_credentials, which are resolved
+    # at execution time from _credentials_id in file field data)
+    graph_cred_inputs = graph.regular_credentials_inputs
 
     for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
         # Best-effort map: skip missing items
|
|||||||
|
|
||||||
# Verify both parent and child status updates
|
# Verify both parent and child status updates
|
||||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for auto_credentials validation in _validate_node_input_credentials
|
||||||
|
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_node_input_credentials_auto_creds_valid(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[SECRT-1772] When a node has auto_credentials with a valid _credentials_id
|
||||||
|
that exists in the store, validation should pass without errors.
|
||||||
|
"""
|
||||||
|
from backend.executor.utils import _validate_node_input_credentials
|
||||||
|
|
||||||
|
mock_node = mocker.MagicMock()
|
||||||
|
mock_node.id = "node-with-auto-creds"
|
||||||
|
mock_node.credentials_optional = False
|
||||||
|
mock_node.input_default = {
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "valid-cred-id",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_block = mocker.MagicMock()
|
||||||
|
# No regular credentials fields
|
||||||
|
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
# Has auto_credentials fields
|
||||||
|
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||||
|
"credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_node.block = mock_block
|
||||||
|
|
||||||
|
mock_graph = mocker.MagicMock()
|
||||||
|
mock_graph.nodes = [mock_node]
|
||||||
|
|
||||||
|
# Mock the credentials store to return valid credentials
|
||||||
|
mock_store = mocker.MagicMock()
|
||||||
|
mock_creds = mocker.MagicMock()
|
||||||
|
mock_creds.id = "valid-cred-id"
|
||||||
|
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.executor.utils.get_integration_credentials_store",
|
||||||
|
return_value=mock_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||||
|
graph=mock_graph,
|
||||||
|
user_id="test-user",
|
||||||
|
nodes_input_masks=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_node.id not in errors
|
||||||
|
assert mock_node.id not in nodes_to_skip
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_node_input_credentials_auto_creds_missing(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[SECRT-1772] When a node has auto_credentials with a _credentials_id
|
||||||
|
that doesn't exist for the current user, validation should report an error.
|
||||||
|
"""
|
||||||
|
from backend.executor.utils import _validate_node_input_credentials
|
||||||
|
|
||||||
|
mock_node = mocker.MagicMock()
|
||||||
|
mock_node.id = "node-with-bad-auto-creds"
|
||||||
|
mock_node.credentials_optional = False
|
||||||
|
mock_node.input_default = {
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "other-users-cred-id",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_block = mocker.MagicMock()
|
||||||
|
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||||
|
"credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_node.block = mock_block
|
||||||
|
|
||||||
|
mock_graph = mocker.MagicMock()
|
||||||
|
mock_graph.nodes = [mock_node]
|
||||||
|
|
||||||
|
# Mock the credentials store to return None (cred not found for this user)
|
||||||
|
mock_store = mocker.MagicMock()
|
||||||
|
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.executor.utils.get_integration_credentials_store",
|
||||||
|
return_value=mock_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||||
|
graph=mock_graph,
|
||||||
|
user_id="different-user",
|
||||||
|
nodes_input_masks=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_node.id in errors
|
||||||
|
assert "spreadsheet" in errors[mock_node.id]
|
||||||
|
assert "not available" in errors[mock_node.id]["spreadsheet"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_node_input_credentials_both_regular_and_auto(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
|
||||||
|
should have both validated.
|
||||||
|
"""
|
||||||
|
from backend.executor.utils import _validate_node_input_credentials
|
||||||
|
|
||||||
|
mock_node = mocker.MagicMock()
|
||||||
|
mock_node.id = "node-with-both-creds"
|
||||||
|
mock_node.credentials_optional = False
|
||||||
|
mock_node.input_default = {
|
||||||
|
"credentials": {
|
||||||
|
"id": "regular-cred-id",
|
||||||
|
"provider": "github",
|
||||||
|
"type": "api_key",
|
||||||
|
},
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": "auto-cred-id",
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_credentials_field_type = mocker.MagicMock()
|
||||||
|
mock_credentials_meta = mocker.MagicMock()
|
||||||
|
mock_credentials_meta.id = "regular-cred-id"
|
||||||
|
mock_credentials_meta.provider = "github"
|
||||||
|
mock_credentials_meta.type = "api_key"
|
||||||
|
mock_credentials_field_type.model_validate.return_value = mock_credentials_meta
|
||||||
|
|
||||||
|
mock_block = mocker.MagicMock()
|
||||||
|
# Regular credentials field
|
||||||
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
|
"credentials": mock_credentials_field_type,
|
||||||
|
}
|
||||||
|
# Auto-credentials field
|
||||||
|
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||||
|
"auto_credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_node.block = mock_block
|
||||||
|
|
||||||
|
mock_graph = mocker.MagicMock()
|
||||||
|
mock_graph.nodes = [mock_node]
|
||||||
|
|
||||||
|
# Mock the credentials store to return valid credentials for both
|
||||||
|
mock_store = mocker.MagicMock()
|
||||||
|
mock_regular_creds = mocker.MagicMock()
|
||||||
|
mock_regular_creds.id = "regular-cred-id"
|
||||||
|
mock_regular_creds.provider = "github"
|
||||||
|
mock_regular_creds.type = "api_key"
|
||||||
|
|
||||||
|
mock_auto_creds = mocker.MagicMock()
|
||||||
|
mock_auto_creds.id = "auto-cred-id"
|
||||||
|
|
||||||
|
def get_creds_side_effect(user_id, cred_id):
|
||||||
|
if cred_id == "regular-cred-id":
|
||||||
|
return mock_regular_creds
|
||||||
|
elif cred_id == "auto-cred-id":
|
||||||
|
return mock_auto_creds
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.executor.utils.get_integration_credentials_store",
|
||||||
|
return_value=mock_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||||
|
graph=mock_graph,
|
||||||
|
user_id="test-user",
|
||||||
|
nodes_input_masks=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both should validate successfully - no errors
|
||||||
|
assert mock_node.id not in errors
|
||||||
|
assert mock_node.id not in nodes_to_skip
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
When a node has auto_credentials but the field value has _credentials_id=None
|
||||||
|
(e.g., from upstream connection), validation should skip it without error.
|
||||||
|
"""
|
||||||
|
from backend.executor.utils import _validate_node_input_credentials
|
||||||
|
|
||||||
|
mock_node = mocker.MagicMock()
|
||||||
|
mock_node.id = "node-with-chained-auto-creds"
|
||||||
|
mock_node.credentials_optional = False
|
||||||
|
mock_node.input_default = {
|
||||||
|
"spreadsheet": {
|
||||||
|
"_credentials_id": None,
|
||||||
|
"id": "file-123",
|
||||||
|
"name": "test.xlsx",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_block = mocker.MagicMock()
|
||||||
|
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||||
|
"credentials": {
|
||||||
|
"field_name": "spreadsheet",
|
||||||
|
"config": {"provider": "google", "type": "oauth2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mock_node.block = mock_block
|
||||||
|
|
||||||
|
mock_graph = mocker.MagicMock()
|
||||||
|
mock_graph.nodes = [mock_node]
|
||||||
|
|
||||||
|
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||||
|
graph=mock_graph,
|
||||||
|
user_id="test-user",
|
||||||
|
nodes_input_masks=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No error - chained data with None cred_id is valid
|
||||||
|
assert mock_node.id not in errors
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_credentials_field_info_auto_credential_tag():
|
||||||
|
"""
|
||||||
|
[Path 4] CredentialsFieldInfo should support is_auto_credential and
|
||||||
|
input_field_name fields for distinguishing auto from regular credentials.
|
||||||
|
"""
|
||||||
|
from backend.data.model import CredentialsFieldInfo
|
||||||
|
|
||||||
|
# Regular credential should have is_auto_credential=False by default
|
||||||
|
regular = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["github"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
assert regular.is_auto_credential is False
|
||||||
|
assert regular.input_field_name is None
|
||||||
|
|
||||||
|
# Auto credential should have is_auto_credential=True
|
||||||
|
auto = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["google"],
|
||||||
|
"credentials_types": ["oauth2"],
|
||||||
|
"is_auto_credential": True,
|
||||||
|
"input_field_name": "spreadsheet",
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
assert auto.is_auto_credential is True
|
||||||
|
assert auto.input_field_name == "spreadsheet"
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_node_credentials_input_map_excludes_auto_creds(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
[Path 4] make_node_credentials_input_map should only include regular credentials,
|
||||||
|
not auto_credentials (which are resolved at execution time).
|
||||||
|
"""
|
||||||
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
|
from backend.executor.utils import make_node_credentials_input_map
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
# Create a mock graph with aggregate_credentials_inputs that returns
|
||||||
|
# both regular and auto credentials
|
||||||
|
mock_graph = mocker.MagicMock()
|
||||||
|
|
||||||
|
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||||
|
{
|
||||||
|
"credentials_provider": ["github"],
|
||||||
|
"credentials_types": ["api_key"],
|
||||||
|
"is_auto_credential": False,
|
||||||
|
},
|
||||||
|
by_alias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock regular_credentials_inputs property (auto_credentials are excluded)
|
||||||
|
mock_graph.regular_credentials_inputs = {
|
||||||
|
"github_creds": (regular_field_info, {("node-1", "credentials")}, True),
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_credentials_input = {
|
||||||
|
"github_creds": CredentialsMetaInput(
|
||||||
|
id="cred-123",
|
||||||
|
provider=ProviderName("github"),
|
||||||
|
type="api_key",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = make_node_credentials_input_map(mock_graph, graph_credentials_input)
|
||||||
|
|
||||||
|
# Regular credentials should be mapped
|
||||||
|
assert "node-1" in result
|
||||||
|
assert "credentials" in result["node-1"]
|
||||||
|
|
||||||
|
# Auto credentials should NOT appear in the result
|
||||||
|
# (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
|
||||||
|
for node_id, fields in result.items():
|
||||||
|
for field_name, value in fields.items():
|
||||||
|
# Verify no auto-credential phantom entries
|
||||||
|
if isinstance(value, dict):
|
||||||
|
assert "_credentials_id" not in value
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ class Flag(str, Enum):
     AGENT_ACTIVITY = "agent-activity"
     ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
     CHAT = "chat"
-    COPILOT_SDK = "copilot-sdk"
 
 
 def is_configured() -> bool:
160
autogpt_platform/backend/poetry.lock
generated
160
autogpt_platform/backend/poetry.lock
generated
@@ -441,14 +441,14 @@ develop = true
|
|||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^46.0"
|
cryptography = "^46.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.128.7"
|
fastapi = "^0.128.0"
|
||||||
google-cloud-logging = "^3.13.0"
|
google-cloud-logging = "^3.13.0"
|
||||||
launchdarkly-server-sdk = "^9.15.0"
|
launchdarkly-server-sdk = "^9.14.1"
|
||||||
pydantic = "^2.12.5"
|
pydantic = "^2.12.5"
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.12.0"
|
||||||
pyjwt = {version = "^2.11.0", extras = ["crypto"]}
|
pyjwt = {version = "^2.11.0", extras = ["crypto"]}
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.28.0"
|
supabase = "^2.27.2"
|
||||||
uvicorn = "^0.40.0"
|
uvicorn = "^0.40.0"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
@@ -897,29 +897,6 @@ files = [
|
|||||||
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "claude-agent-sdk"
|
|
||||||
version = "0.1.35"
|
|
||||||
description = "Python SDK for Claude Code"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.10"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
anyio = ">=4.0.0"
|
|
||||||
mcp = ">=0.1.0"
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cleo"
|
name = "cleo"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@@ -1405,14 +1382,14 @@ tzdata = "*"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.128.7"
|
version = "0.128.6"
|
||||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
|
{file = "fastapi-0.128.6-py3-none-any.whl", hash = "sha256:bb1c1ef87d6086a7132d0ab60869d6f1ee67283b20fbf84ec0003bd335099509"},
|
||||||
{file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
|
{file = "fastapi-0.128.6.tar.gz", hash = "sha256:0cb3946557e792d731b26a42b04912f16367e3c3135ea8290f620e234f2b604f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
|
|||||||
socks = ["socksio (==1.*)"]
|
socks = ["socksio (==1.*)"]
|
||||||
zstd = ["zstandard (>=0.18.0)"]
|
zstd = ["zstandard (>=0.18.0)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "httpx-sse"
|
|
||||||
version = "0.4.3"
|
|
||||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.9"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
|
|
||||||
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huggingface-hub"
|
name = "huggingface-hub"
|
||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
@@ -3152,14 +3117,14 @@ urllib3 = ">=1.26.0,<3"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "launchdarkly-server-sdk"
|
name = "launchdarkly-server-sdk"
|
||||||
version = "9.15.0"
|
version = "9.14.1"
|
||||||
description = "LaunchDarkly SDK for Python"
|
description = "LaunchDarkly SDK for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
|
{file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
|
||||||
{file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
|
{file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -3345,39 +3310,6 @@ files = [
|
|||||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "mcp"
|
|
||||||
version = "1.26.0"
|
|
||||||
description = "Model Context Protocol SDK"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.10"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
|
|
||||||
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
anyio = ">=4.5"
|
|
||||||
httpx = ">=0.27.1"
|
|
||||||
httpx-sse = ">=0.4"
|
|
||||||
jsonschema = ">=4.20.0"
|
|
||||||
pydantic = ">=2.11.0,<3.0.0"
|
|
||||||
pydantic-settings = ">=2.5.2"
|
|
||||||
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
|
||||||
python-multipart = ">=0.0.9"
|
|
||||||
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
|
|
||||||
sse-starlette = ">=1.6.1"
|
|
||||||
starlette = ">=0.27"
|
|
||||||
typing-extensions = ">=4.9.0"
|
|
||||||
typing-inspection = ">=0.4.1"
|
|
||||||
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
|
|
||||||
rich = ["rich (>=13.9.4)"]
|
|
||||||
ws = ["websockets (>=15.0.1)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mdurl"
|
name = "mdurl"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
@@ -4796,14 +4728,14 @@ tests = ["coverage-conditional-plugin (>=0.9.0)", "portalocker[redis]", "pytest

 [[package]]
 name = "postgrest"
-version = "2.28.0"
+version = "2.27.3"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
+    {file = "postgrest-2.27.3-py3-none-any.whl", hash = "sha256:ed79123af7127edd78d538bfe8351d277e45b1a36994a4dbf57ae27dde87a7b7"},
-    {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
+    {file = "postgrest-2.27.3.tar.gz", hash = "sha256:c2e2679addfc8eaab23197bad7ddaee6cbb4cbe8c483ebd2d2e5219543037cc3"},
 ]

 [package.dependencies]
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
 optional = false
 python-versions = "*"
 groups = ["main"]
-markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
+markers = "platform_system == \"Windows\""
 files = [
     {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
     {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -6328,14 +6260,14 @@ all = ["numpy"]

 [[package]]
 name = "realtime"
-version = "2.28.0"
+version = "2.27.3"
 description = ""
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
+    {file = "realtime-2.27.3-py3-none-any.whl", hash = "sha256:f571115f86988e33c41c895cb3fba2eaa1b693aeaede3617288f44274ca90f43"},
-    {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
+    {file = "realtime-2.27.3.tar.gz", hash = "sha256:02b082243107656a5ef3fb63e8e2ab4c40bc199abb45adb8a42ed63f089a1041"},
 ]

 [package.dependencies]
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
 pymysql = ["pymysql"]
 sqlcipher = ["sqlcipher3_binary"]

-[[package]]
-name = "sse-starlette"
-version = "3.2.0"
-description = "SSE plugin for Starlette"
-optional = false
-python-versions = ">=3.9"
-groups = ["main"]
-files = [
-    {file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
-    {file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
-]
-
-[package.dependencies]
-anyio = ">=4.7.0"
-starlette = ">=0.49.1"
-
-[package.extras]
-daphne = ["daphne (>=4.2.0)"]
-examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
-granian = ["granian (>=2.3.1)"]
-uvicorn = ["uvicorn (>=0.34.0)"]
-
 [[package]]
 name = "stagehand"
 version = "0.5.9"
@@ -7114,14 +7024,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart

 [[package]]
 name = "storage3"
-version = "2.28.0"
+version = "2.27.3"
 description = "Supabase Storage client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
+    {file = "storage3-2.27.3-py3-none-any.whl", hash = "sha256:11a05b7da84bccabeeea12d940bca3760cf63fe6ca441868677335cfe4fdfbe0"},
-    {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
+    {file = "storage3-2.27.3.tar.gz", hash = "sha256:dc1a4a010cf36d5482c5cb6c1c28fc5f00e23284342b89e4ae43b5eae8501ddb"},
 ]

 [package.dependencies]
@@ -7181,35 +7091,35 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}

 [[package]]
 name = "supabase"
-version = "2.28.0"
+version = "2.27.3"
 description = "Supabase client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
+    {file = "supabase-2.27.3-py3-none-any.whl", hash = "sha256:082a74642fcf9954693f1ce8c251baf23e4bda26ffdbc8dcd4c99c82e60d69ff"},
-    {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
+    {file = "supabase-2.27.3.tar.gz", hash = "sha256:5e5a348232ac4315c1032ddd687278f0b982465471f0cbb52bca7e6a66495ff3"},
 ]

 [package.dependencies]
 httpx = ">=0.26,<0.29"
-postgrest = "2.28.0"
+postgrest = "2.27.3"
-realtime = "2.28.0"
+realtime = "2.27.3"
-storage3 = "2.28.0"
+storage3 = "2.27.3"
-supabase-auth = "2.28.0"
+supabase-auth = "2.27.3"
-supabase-functions = "2.28.0"
+supabase-functions = "2.27.3"
 yarl = ">=1.22.0"

 [[package]]
 name = "supabase-auth"
-version = "2.28.0"
+version = "2.27.3"
 description = "Python Client Library for Supabase Auth"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
+    {file = "supabase_auth-2.27.3-py3-none-any.whl", hash = "sha256:82a4262eaad85383319d394dab0eea11fcf3ebd774062aef8ea3874ae2f02579"},
-    {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
+    {file = "supabase_auth-2.27.3.tar.gz", hash = "sha256:39894d4bc60b6f23b5cff4d0d7d4c1659e5d69563cadf014d4896f780ca8ca78"},
 ]

 [package.dependencies]
@@ -7219,14 +7129,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}

 [[package]]
 name = "supabase-functions"
-version = "2.28.0"
+version = "2.27.3"
 description = "Library for Supabase Functions"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
+    {file = "supabase_functions-2.27.3-py3-none-any.whl", hash = "sha256:9d14a931d49ede1c6cf5fbfceb11c44061535ba1c3f310f15384964d86a83d9e"},
-    {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
+    {file = "supabase_functions-2.27.3.tar.gz", hash = "sha256:e954f1646da8ca6e7e16accef58d0884a5f97b25956ee98e7d4927a210ed92f9"},
 ]

 [package.dependencies]
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.14"
-content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
+content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af"
@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
 apscheduler = "^3.11.1"
 autogpt-libs = { path = "../autogpt_libs", develop = true }
 bleach = { extras = ["css"], version = "^6.2.0" }
-claude-agent-sdk = "^0.1.0"
 click = "^8.2.0"
 cryptography = "^46.0"
 discord-py = "^2.5.2"

@@ -66,7 +65,7 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.28.0"
+supabase = "2.27.3"
 tenacity = "^9.1.4"
 todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
@@ -1,133 +0,0 @@
-"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
-
-These are pure unit tests with no external dependencies (no SDK, no DB, no server).
-They validate that the security hooks correctly block unauthorized paths,
-tool access, and dangerous input patterns.
-
-Note: Bash command validation was removed — the SDK built-in Bash tool is not in
-allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
-(unshare --net) making command-level parsing unnecessary.
-"""
-
-from backend.api.features.chat.sdk.security_hooks import (
-    _validate_tool_access,
-    _validate_workspace_path,
-)
-
-SDK_CWD = "/tmp/copilot-test-session"
-
-
-def _is_denied(result: dict) -> bool:
-    hook = result.get("hookSpecificOutput", {})
-    return hook.get("permissionDecision") == "deny"
-
-
-def _reason(result: dict) -> str:
-    return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
-
-
-# ============================================================
-# Workspace path validation (Read, Write, Edit, etc.)
-# ============================================================
-
-
-class TestWorkspacePathValidation:
-    def test_path_in_workspace(self):
-        result = _validate_workspace_path(
-            "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
-        )
-        assert not _is_denied(result)
-
-    def test_path_outside_workspace(self):
-        result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
-        assert _is_denied(result)
-
-    def test_tool_results_allowed(self):
-        result = _validate_workspace_path(
-            "Read",
-            {"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
-            SDK_CWD,
-        )
-        assert not _is_denied(result)
-
-    def test_claude_settings_blocked(self):
-        result = _validate_workspace_path(
-            "Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
-        )
-        assert _is_denied(result)
-
-    def test_claude_projects_without_tool_results(self):
-        result = _validate_workspace_path(
-            "Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
-        )
-        assert _is_denied(result)
-
-    def test_no_path_allowed(self):
-        """Glob/Grep without path defaults to cwd — should be allowed."""
-        result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
-        assert not _is_denied(result)
-
-    def test_path_traversal_with_dotdot(self):
-        result = _validate_workspace_path(
-            "Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
-        )
-        assert _is_denied(result)
-
-
-# ============================================================
-# Tool access validation
-# ============================================================
-
-
-class TestToolAccessValidation:
-    def test_blocked_tools(self):
-        for tool in ("bash", "shell", "exec", "terminal", "command"):
-            result = _validate_tool_access(tool, {})
-            assert _is_denied(result), f"Tool '{tool}' should be blocked"
-
-    def test_bash_builtin_blocked(self):
-        """SDK built-in Bash (capital) is blocked as defence-in-depth."""
-        result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
-        assert _is_denied(result)
-        assert "Bash" in _reason(result)
-
-    def test_workspace_tools_delegate(self):
-        result = _validate_tool_access(
-            "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
-        )
-        assert not _is_denied(result)
-
-    def test_dangerous_pattern_blocked(self):
-        result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
-        assert _is_denied(result)
-
-    def test_safe_unknown_tool_allowed(self):
-        result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
-        assert not _is_denied(result)
-
-
-# ============================================================
-# Deny message quality (ntindle feedback)
-# ============================================================
-
-
-class TestDenyMessageClarity:
-    """Deny messages must include [SECURITY] and 'cannot be bypassed'
-    so the model knows the restriction is enforced, not a suggestion."""
-
-    def test_blocked_tool_message(self):
-        reason = _reason(_validate_tool_access("bash", {}))
-        assert "[SECURITY]" in reason
-        assert "cannot be bypassed" in reason
-
-    def test_bash_builtin_blocked_message(self):
-        reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
-        assert "[SECURITY]" in reason
-        assert "cannot be bypassed" in reason
-
-    def test_workspace_path_message(self):
-        reason = _reason(
-            _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
-        )
-        assert "[SECURITY]" in reason
-        assert "cannot be bypassed" in reason
@@ -1,255 +0,0 @@
|
|||||||
"""Unit tests for JSONL transcript management utilities."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from backend.api.features.chat.sdk.transcript import (
|
|
||||||
STRIPPABLE_TYPES,
|
|
||||||
read_transcript_file,
|
|
||||||
strip_progress_entries,
|
|
||||||
validate_transcript,
|
|
||||||
write_transcript_to_tempfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_jsonl(*entries: dict) -> str:
|
|
||||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# --- Fixtures ---
|
|
||||||
|
|
||||||
|
|
||||||
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
|
|
||||||
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
|
|
||||||
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
|
||||||
ASST_MSG = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"message": {"role": "assistant", "content": "hello"},
|
|
||||||
}
|
|
||||||
PROGRESS_ENTRY = {
|
|
||||||
"type": "progress",
|
|
||||||
"uuid": "p1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
|
||||||
}
|
|
||||||
|
|
||||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
|
||||||
|
|
||||||
|
|
||||||
# --- read_transcript_file ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestReadTranscriptFile:
|
|
||||||
def test_returns_content_for_valid_file(self, tmp_path):
|
|
||||||
path = tmp_path / "session.jsonl"
|
|
||||||
path.write_text(VALID_TRANSCRIPT)
|
|
||||||
result = read_transcript_file(str(path))
|
|
||||||
assert result is not None
|
|
||||||
assert "user" in result
|
|
||||||
|
|
||||||
def test_returns_none_for_missing_file(self):
|
|
||||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
|
||||||
|
|
||||||
def test_returns_none_for_empty_path(self):
|
|
||||||
assert read_transcript_file("") is None
|
|
||||||
|
|
||||||
def test_returns_none_for_empty_file(self, tmp_path):
|
|
||||||
path = tmp_path / "empty.jsonl"
|
|
||||||
path.write_text("")
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
|
||||||
path = tmp_path / "meta.jsonl"
|
|
||||||
path.write_text(content)
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
|
||||||
path = tmp_path / "bad.jsonl"
|
|
||||||
path.write_text("not json\n{}\n{}\n")
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_no_size_limit(self, tmp_path):
|
|
||||||
"""Large files are accepted — bucket storage has no size limit."""
|
|
||||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
|
||||||
path = tmp_path / "big.jsonl"
|
|
||||||
path.write_text(content)
|
|
||||||
result = read_transcript_file(str(path))
|
|
||||||
assert result is not None
|
|
||||||
|
|
||||||
|
|
||||||
# --- write_transcript_to_tempfile ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestWriteTranscriptToTempfile:
|
|
||||||
"""Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""
|
|
||||||
|
|
||||||
def test_writes_file_and_returns_path(self):
|
|
||||||
cwd = "/tmp/copilot-test-write"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(
|
|
||||||
VALID_TRANSCRIPT, "sess-1234-abcd", cwd
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
assert os.path.isfile(result)
|
|
||||||
assert result.endswith(".jsonl")
|
|
||||||
with open(result) as f:
|
|
||||||
assert f.read() == VALID_TRANSCRIPT
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_creates_parent_directory(self):
|
|
||||||
cwd = "/tmp/copilot-test-mkdir"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
|
||||||
assert result is not None
|
|
||||||
assert os.path.isdir(cwd)
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_uses_session_id_prefix(self):
|
|
||||||
cwd = "/tmp/copilot-test-prefix"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(
|
|
||||||
VALID_TRANSCRIPT, "abcdef12-rest", cwd
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
assert "abcdef12" in os.path.basename(result)
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_rejects_cwd_outside_sandbox(self, tmp_path):
|
|
||||||
cwd = str(tmp_path / "not-copilot")
|
|
||||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
# --- validate_transcript ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidateTranscript:
|
|
||||||
def test_valid_transcript(self):
|
|
||||||
assert validate_transcript(VALID_TRANSCRIPT) is True
|
|
||||||
|
|
||||||
def test_none_content(self):
|
|
||||||
assert validate_transcript(None) is False
|
|
||||||
|
|
||||||
def test_empty_content(self):
|
|
||||||
assert validate_transcript("") is False
|
|
||||||
|
|
||||||
def test_metadata_only(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_user_only_no_assistant(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_assistant_only_no_user(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_invalid_json_returns_false(self):
|
|
||||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
|
||||||
|
|
||||||
|
|
||||||
# --- strip_progress_entries ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestStripProgressEntries:
|
|
||||||
def test_strips_all_strippable_types(self):
|
|
||||||
"""All STRIPPABLE_TYPES are removed from the output."""
|
|
||||||
entries = [
|
|
||||||
USER_MSG,
|
|
||||||
{"type": "progress", "uuid": "p1", "parentUuid": "u1"},
|
|
||||||
{"type": "file-history-snapshot", "files": []},
|
|
||||||
{"type": "queue-operation", "subtype": "create"},
|
|
||||||
{"type": "summary", "text": "..."},
|
|
||||||
{"type": "pr-link", "url": "..."},
|
|
||||||
ASST_MSG,
|
|
||||||
]
|
|
||||||
result = strip_progress_entries(_make_jsonl(*entries))
|
|
||||||
result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
|
|
||||||
assert result_types == {"user", "assistant"}
|
|
||||||
for stype in STRIPPABLE_TYPES:
|
|
||||||
assert stype not in result_types
|
|
||||||
|
|
||||||
def test_reparents_children_of_stripped_entries(self):
|
|
||||||
"""An assistant message whose parent is a progress entry gets reparented."""
|
|
||||||
progress = {
|
|
||||||
"type": "progress",
|
|
||||||
"uuid": "p1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"data": {"type": "bash_progress"},
|
|
||||||
}
|
|
||||||
asst = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "p1", # Points to progress
|
|
||||||
"message": {"role": "assistant", "content": "done"},
|
|
||||||
}
|
|
||||||
content = _make_jsonl(USER_MSG, progress, asst)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
|
||||||
|
|
||||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
|
||||||
# Should be reparented to u1 (the user message)
|
|
||||||
assert asst_entry["parentUuid"] == "u1"
|
|
||||||
|
|
||||||
def test_reparents_through_chain(self):
|
|
||||||
"""Reparenting walks through multiple stripped entries."""
|
|
||||||
p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
|
||||||
p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
|
|
||||||
p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
|
|
||||||
asst = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "p3", # 3 levels deep
|
|
||||||
"message": {"role": "assistant", "content": "done"},
|
|
||||||
}
|
|
||||||
content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
|
||||||
|
|
||||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
|
||||||
assert asst_entry["parentUuid"] == "u1"
|
|
||||||
|
|
||||||
def test_preserves_non_strippable_entries(self):
|
|
||||||
"""User, assistant, and system entries are preserved."""
|
|
||||||
system = {"type": "system", "uuid": "s1", "message": "prompt"}
|
|
||||||
content = _make_jsonl(system, USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
|
||||||
assert result_types == ["system", "user", "assistant"]
|
|
||||||
|
|
||||||
def test_empty_input(self):
|
|
||||||
result = strip_progress_entries("")
|
|
||||||
# Should return just a newline (empty content stripped)
|
|
||||||
assert result.strip() == ""
|
|
||||||
|
|
||||||
def test_no_strippable_entries(self):
|
|
||||||
"""When there's nothing to strip, output matches input structure."""
|
|
||||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_lines = result.strip().split("\n")
|
|
||||||
assert len(result_lines) == 2
|
|
||||||
|
|
||||||
def test_handles_entries_without_uuid(self):
|
|
||||||
"""Entries without uuid field are handled gracefully."""
|
|
||||||
no_uuid = {"type": "queue-operation", "subtype": "create"}
|
|
||||||
content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
|
||||||
# queue-operation is strippable
|
|
||||||
assert "queue-operation" not in result_types
|
|
||||||
assert "user" in result_types
|
|
||||||
assert "assistant" in result_types
|
|
||||||
@@ -37,7 +37,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: migrate
-    command: ["sh", "-c", "prisma generate && python3 gen_prisma_types_stub.py && prisma migrate deploy"]
+    command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
     develop:
       watch:
         - path: ./

@@ -56,7 +56,7 @@ services:
      test:
        [
          "CMD-SHELL",
-         "prisma migrate status | grep -q 'No pending migrations' || exit 1",
+         "poetry run prisma migrate status | grep -q 'No pending migrations' || exit 1",
        ]
      interval: 30s
      timeout: 10s
@@ -22,11 +22,6 @@ Sentry.init({

   enabled: shouldEnable,

-  // Suppress cross-origin stylesheet errors from Sentry Replay (rrweb)
-  // serializing DOM snapshots with cross-origin stylesheets
-  // (e.g., from browser extensions or CDN-loaded CSS)
-  ignoreErrors: [/Not allowed to access cross-origin stylesheet/],
-
   // Add optional integrations for additional features
   integrations: [
     Sentry.captureConsoleIntegration(),
@@ -20,7 +20,6 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
 import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
 import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
 import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
-import { GenericTool } from "../../tools/GenericTool/GenericTool";
 import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";

 // ---------------------------------------------------------------------------

@@ -256,16 +255,6 @@ export const ChatMessagesContainer = ({
             />
           );
         default:
-          // Render a generic tool indicator for SDK built-in
-          // tools (Read, Glob, Grep, etc.) or any unrecognized tool
-          if (part.type.startsWith("tool-")) {
-            return (
-              <GenericTool
-                key={`${message.id}-${i}`}
-                part={part as ToolUIPart}
-              />
-            );
-          }
           return null;
       }
     })}
@@ -1,63 +0,0 @@
-"use client";
-
-import { ToolUIPart } from "ai";
-import { GearIcon } from "@phosphor-icons/react";
-import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
-
-interface Props {
-  part: ToolUIPart;
-}
-
-function extractToolName(part: ToolUIPart): string {
-  // ToolUIPart.type is "tool-{name}", extract the name portion.
-  return part.type.replace(/^tool-/, "");
-}
-
-function formatToolName(name: string): string {
-  // "search_docs" → "Search docs", "Read" → "Read"
-  return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
-}
-
-function getAnimationText(part: ToolUIPart): string {
-  const label = formatToolName(extractToolName(part));
-
-  switch (part.state) {
-    case "input-streaming":
-    case "input-available":
-      return `Running ${label}…`;
-    case "output-available":
-      return `${label} completed`;
-    case "output-error":
-      return `${label} failed`;
-    default:
-      return `Running ${label}…`;
-  }
-}
-
-export function GenericTool({ part }: Props) {
-  const isStreaming =
-    part.state === "input-streaming" || part.state === "input-available";
-  const isError = part.state === "output-error";
-
-  return (
-    <div className="py-2">
-      <div className="flex items-center gap-2 text-sm text-muted-foreground">
-        <GearIcon
-          size={14}
-          weight="regular"
-          className={
-            isError
-              ? "text-red-500"
-              : isStreaming
-                ? "animate-spin text-neutral-500"
-                : "text-neutral-400"
-          }
-        />
-        <MorphingTextAnimation
-          text={getAnimationText(part)}
-          className={isError ? "text-red-500" : undefined}
-        />
-      </div>
-    </div>
-  );
-}
@@ -3,7 +3,6 @@
 import type { ToolUIPart } from "ai";
 import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
 import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
-import { BlockDetailsCard } from "./components/BlockDetailsCard/BlockDetailsCard";
 import { BlockOutputCard } from "./components/BlockOutputCard/BlockOutputCard";
 import { ErrorCard } from "./components/ErrorCard/ErrorCard";
 import { SetupRequirementsCard } from "./components/SetupRequirementsCard/SetupRequirementsCard";

@@ -12,7 +11,6 @@ import {
   getAnimationText,
   getRunBlockToolOutput,
   isRunBlockBlockOutput,
-  isRunBlockDetailsOutput,
   isRunBlockErrorOutput,
   isRunBlockSetupRequirementsOutput,
   ToolIcon,

@@ -43,7 +41,6 @@ export function RunBlockTool({ part }: Props) {
     part.state === "output-available" &&
     !!output &&
     (isRunBlockBlockOutput(output) ||
-      isRunBlockDetailsOutput(output) ||
       isRunBlockSetupRequirementsOutput(output) ||
       isRunBlockErrorOutput(output));

@@ -61,10 +58,6 @@ export function RunBlockTool({ part }: Props) {
     <ToolAccordion {...getAccordionMeta(output)}>
       {isRunBlockBlockOutput(output) && <BlockOutputCard output={output} />}

-      {isRunBlockDetailsOutput(output) && (
-        <BlockDetailsCard output={output} />
-      )}
-
       {isRunBlockSetupRequirementsOutput(output) && (
         <SetupRequirementsCard output={output} />
       )}
@@ -1,188 +0,0 @@
|
|||||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
|
||||||
import { ResponseType } from "@/app/api/__generated__/models/responseType";
|
|
||||||
import type { BlockDetailsResponse } from "../../helpers";
|
|
||||||
import { BlockDetailsCard } from "./BlockDetailsCard";
|
|
||||||
|
|
||||||
const meta: Meta<typeof BlockDetailsCard> = {
|
|
||||||
title: "Copilot/RunBlock/BlockDetailsCard",
|
|
||||||
component: BlockDetailsCard,
|
|
||||||
parameters: {
|
|
||||||
layout: "centered",
|
|
||||||
},
|
|
||||||
tags: ["autodocs"],
|
|
||||||
decorators: [
|
|
||||||
(Story) => (
|
|
||||||
<div style={{ maxWidth: 480 }}>
|
|
||||||
<Story />
|
|
||||||
</div>
|
|
||||||
),
|
|
||||||
],
|
|
||||||
};
|
|
||||||
|
|
||||||
export default meta;
|
|
||||||
type Story = StoryObj<typeof meta>;
|
|
||||||
|
|
||||||
const baseBlock: BlockDetailsResponse = {
|
|
||||||
type: ResponseType.block_details,
|
|
||||||
message:
|
|
||||||
"Here are the details for the GetWeather block. Provide the required inputs to run it.",
|
|
||||||
session_id: "session-123",
|
|
||||||
user_authenticated: true,
|
|
||||||
block: {
|
|
||||||
id: "block-abc-123",
|
|
||||||
name: "GetWeather",
|
|
||||||
description: "Fetches current weather data for a given location.",
|
|
||||||
inputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
location: {
|
|
||||||
title: "Location",
|
|
||||||
type: "string",
|
|
||||||
description:
|
|
||||||
"City name or coordinates (e.g. 'London' or '51.5,-0.1')",
|
|
||||||
},
|
|
||||||
units: {
|
|
||||||
title: "Units",
|
|
||||||
type: "string",
|
|
||||||
description: "Temperature units: 'metric' or 'imperial'",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ["location"],
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
temperature: {
|
|
||||||
title: "Temperature",
|
|
||||||
type: "number",
|
|
||||||
description: "Current temperature in the requested units",
|
|
||||||
},
|
|
||||||
condition: {
|
|
||||||
title: "Condition",
|
|
||||||
type: "string",
|
|
||||||
description: "Weather condition description (e.g. 'Sunny', 'Rain')",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
credentials: [],
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export const Default: Story = {
|
|
||||||
args: {
|
|
||||||
output: baseBlock,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export const InputsOnly: Story = {
|
|
||||||
args: {
|
|
||||||
output: {
|
|
||||||
...baseBlock,
|
|
||||||
message: "This block requires inputs. No outputs are defined.",
|
|
||||||
block: {
|
|
||||||
...baseBlock.block,
|
|
||||||
outputs: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export const OutputsOnly: Story = {
|
|
||||||
args: {
|
|
||||||
output: {
|
|
||||||
...baseBlock,
|
|
||||||
message: "This block has no required inputs.",
|
|
||||||
block: {
|
|
||||||
...baseBlock.block,
|
|
||||||
inputs: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export const ManyFields: Story = {
|
|
||||||
args: {
|
|
||||||
output: {
|
|
||||||
...baseBlock,
|
|
||||||
message: "Block with many input and output fields.",
|
|
||||||
block: {
|
|
||||||
...baseBlock.block,
|
|
||||||
name: "SendEmail",
|
|
||||||
description: "Sends an email via SMTP.",
|
|
||||||
inputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
to: {
|
|
||||||
title: "To",
|
|
||||||
type: "string",
|
|
||||||
description: "Recipient email address",
|
|
||||||
},
|
|
||||||
subject: {
|
|
||||||
title: "Subject",
|
|
||||||
type: "string",
|
|
||||||
description: "Email subject line",
|
|
||||||
},
|
|
||||||
body: {
|
|
||||||
title: "Body",
|
|
||||||
type: "string",
|
|
||||||
description: "Email body content",
|
|
||||||
},
|
|
||||||
cc: {
|
|
||||||
title: "CC",
|
|
||||||
type: "string",
|
|
||||||
description: "CC recipients (comma-separated)",
|
|
||||||
},
|
|
||||||
bcc: {
|
|
||||||
title: "BCC",
|
|
||||||
type: "string",
|
|
||||||
description: "BCC recipients (comma-separated)",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ["to", "subject", "body"],
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
message_id: {
|
|
||||||
title: "Message ID",
|
|
||||||
type: "string",
|
|
||||||
description: "Unique ID of the sent email",
|
|
||||||
},
|
|
||||||
status: {
|
|
||||||
title: "Status",
|
|
||||||
type: "string",
|
|
||||||
description: "Delivery status",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export const NoFieldDescriptions: Story = {
|
|
||||||
args: {
|
|
||||||
output: {
|
|
||||||
...baseBlock,
|
|
||||||
message: "Fields without descriptions.",
|
|
||||||
block: {
|
|
||||||
...baseBlock.block,
|
|
||||||
name: "SimpleBlock",
|
|
||||||
inputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
input_a: { title: "Input A", type: "string" },
|
|
||||||
input_b: { title: "Input B", type: "number" },
|
|
||||||
},
|
|
||||||
required: ["input_a"],
|
|
||||||
},
|
|
||||||
outputs: {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
result: { title: "Result", type: "string" },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
@@ -1,103 +0,0 @@
-"use client";
-
-import type { BlockDetailsResponse } from "../../helpers";
-import {
-  ContentBadge,
-  ContentCard,
-  ContentCardDescription,
-  ContentCardTitle,
-  ContentGrid,
-  ContentMessage,
-} from "../../../../components/ToolAccordion/AccordionContent";
-
-interface Props {
-  output: BlockDetailsResponse;
-}
-
-function SchemaFieldList({
-  title,
-  properties,
-  required,
-}: {
-  title: string;
-  properties: Record<string, unknown>;
-  required?: string[];
-}) {
-  const entries = Object.entries(properties);
-  if (entries.length === 0) return null;
-
-  const requiredSet = new Set(required ?? []);
-
-  return (
-    <ContentCard>
-      <ContentCardTitle className="text-xs">{title}</ContentCardTitle>
-      <div className="mt-2 grid gap-2">
-        {entries.map(([name, schema]) => {
-          const field = schema as Record<string, unknown> | undefined;
-          const fieldTitle =
-            typeof field?.title === "string" ? field.title : name;
-          const fieldType =
-            typeof field?.type === "string" ? field.type : "unknown";
-          const description =
-            typeof field?.description === "string"
-              ? field.description
-              : undefined;
-
-          return (
-            <div key={name} className="rounded-xl border p-2">
-              <div className="flex items-center justify-between gap-2">
-                <ContentCardTitle className="text-xs">
-                  {fieldTitle}
-                </ContentCardTitle>
-                <div className="flex gap-1">
-                  <ContentBadge>{fieldType}</ContentBadge>
-                  {requiredSet.has(name) && (
-                    <ContentBadge>Required</ContentBadge>
-                  )}
-                </div>
-              </div>
-              {description && (
-                <ContentCardDescription className="mt-1 text-xs">
-                  {description}
-                </ContentCardDescription>
-              )}
-            </div>
-          );
-        })}
-      </div>
-    </ContentCard>
-  );
-}
-
-export function BlockDetailsCard({ output }: Props) {
-  const inputs = output.block.inputs as {
-    properties?: Record<string, unknown>;
-    required?: string[];
-  } | null;
-  const outputs = output.block.outputs as {
-    properties?: Record<string, unknown>;
-    required?: string[];
-  } | null;
-
-  return (
-    <ContentGrid>
-      <ContentMessage>{output.message}</ContentMessage>
-
-      {inputs?.properties && Object.keys(inputs.properties).length > 0 && (
-        <SchemaFieldList
-          title="Inputs"
-          properties={inputs.properties}
-          required={inputs.required}
-        />
-      )}
-
-      {outputs?.properties && Object.keys(outputs.properties).length > 0 && (
-        <SchemaFieldList
-          title="Outputs"
-          properties={outputs.properties}
-          required={outputs.required}
-        />
-      )}
-    </ContentGrid>
-  );
-}
@@ -10,37 +10,18 @@ import {
 import type { ToolUIPart } from "ai";
 import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";

-/** Block details returned on first run_block attempt (before input_data provided). */
-export interface BlockDetailsResponse {
-  type: typeof ResponseType.block_details;
-  message: string;
-  session_id?: string | null;
-  block: {
-    id: string;
-    name: string;
-    description: string;
-    inputs: Record<string, unknown>;
-    outputs: Record<string, unknown>;
-    credentials: unknown[];
-  };
-  user_authenticated: boolean;
-}
-
 export interface RunBlockInput {
   block_id?: string;
-  block_name?: string;
   input_data?: Record<string, unknown>;
 }

 export type RunBlockToolOutput =
   | SetupRequirementsResponse
-  | BlockDetailsResponse
   | BlockOutputResponse
   | ErrorResponse;

 const RUN_BLOCK_OUTPUT_TYPES = new Set<string>([
   ResponseType.setup_requirements,
-  ResponseType.block_details,
   ResponseType.block_output,
   ResponseType.error,
 ]);

@@ -54,15 +35,6 @@ export function isRunBlockSetupRequirementsOutput(
   );
 }

-export function isRunBlockDetailsOutput(
-  output: RunBlockToolOutput,
-): output is BlockDetailsResponse {
-  return (
-    output.type === ResponseType.block_details ||
-    ("block" in output && typeof output.block === "object")
-  );
-}
-
 export function isRunBlockBlockOutput(
   output: RunBlockToolOutput,
 ): output is BlockOutputResponse {

@@ -92,7 +64,6 @@ function parseOutput(output: unknown): RunBlockToolOutput | null {
     return output as RunBlockToolOutput;
   }
   if ("block_id" in output) return output as BlockOutputResponse;
-  if ("block" in output) return output as BlockDetailsResponse;
   if ("setup_info" in output) return output as SetupRequirementsResponse;
   if ("error" in output || "details" in output)
     return output as ErrorResponse;

@@ -113,25 +84,17 @@ export function getAnimationText(part: {
   output?: unknown;
 }): string {
   const input = part.input as RunBlockInput | undefined;
-  const blockName = input?.block_name?.trim();
   const blockId = input?.block_id?.trim();
-  // Prefer block_name if available, otherwise fall back to block_id
-  const blockText = blockName
-    ? ` "${blockName}"`
-    : blockId
-      ? ` "${blockId}"`
-      : "";
+  const blockText = blockId ? ` "${blockId}"` : "";

   switch (part.state) {
     case "input-streaming":
     case "input-available":
-      return `Running${blockText}`;
+      return `Running the block${blockText}`;
     case "output-available": {
       const output = parseOutput(part.output);
-      if (!output) return `Running${blockText}`;
+      if (!output) return `Running the block${blockText}`;
       if (isRunBlockBlockOutput(output)) return `Ran "${output.block_name}"`;
-      if (isRunBlockDetailsOutput(output))
-        return `Details for "${output.block.name}"`;
       if (isRunBlockSetupRequirementsOutput(output)) {
         return `Setup needed for "${output.setup_info.agent_name}"`;
       }

@@ -195,21 +158,6 @@ export function getAccordionMeta(output: RunBlockToolOutput): {
     };
   }

-  if (isRunBlockDetailsOutput(output)) {
-    const inputKeys = Object.keys(
-      (output.block.inputs as { properties?: Record<string, unknown> })
-        ?.properties ?? {},
-    );
-    return {
-      icon,
-      title: output.block.name,
-      description:
-        inputKeys.length > 0
-          ? `${inputKeys.length} input field${inputKeys.length === 1 ? "" : "s"} available`
-          : output.message,
-    };
-  }
-
   if (isRunBlockSetupRequirementsOutput(output)) {
     const missingCredsCount = Object.keys(
       (output.setup_info.user_readiness?.missing_credentials ?? {}) as Record<
|||||||
@@ -29,7 +29,6 @@ export function ScheduleListItem({
|
|||||||
description={formatDistanceToNow(schedule.next_run_time, {
|
description={formatDistanceToNow(schedule.next_run_time, {
|
||||||
addSuffix: true,
|
addSuffix: true,
|
||||||
})}
|
})}
|
||||||
descriptionTitle={new Date(schedule.next_run_time).toString()}
|
|
||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
selected={selected}
|
selected={selected}
|
||||||
icon={
|
icon={
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import React from "react";
 interface Props {
   title: string;
   description?: string;
-  descriptionTitle?: string;
   icon?: React.ReactNode;
   selected?: boolean;
   onClick?: () => void;

@@ -17,7 +16,6 @@ interface Props {
 export function SidebarItemCard({
   title,
   description,
-  descriptionTitle,
   icon,
   selected,
   onClick,

@@ -40,11 +38,7 @@ export function SidebarItemCard({
       >
         {title}
       </Text>
-      <Text
-        variant="body"
-        className="leading-tight !text-zinc-500"
-        title={descriptionTitle}
-      >
+      <Text variant="body" className="leading-tight !text-zinc-500">
         {description}
       </Text>
     </div>

@@ -81,9 +81,6 @@ export function TaskListItem({
         ? formatDistanceToNow(run.started_at, { addSuffix: true })
         : "—"
     }
-    descriptionTitle={
-      run.started_at ? new Date(run.started_at).toString() : undefined
-    }
     onClick={onClick}
     selected={selected}
     actions={
@@ -1053,7 +1053,6 @@
       "$ref": "#/components/schemas/ClarificationNeededResponse"
     },
     { "$ref": "#/components/schemas/BlockListResponse" },
-    { "$ref": "#/components/schemas/BlockDetailsResponse" },
     { "$ref": "#/components/schemas/BlockOutputResponse" },
     { "$ref": "#/components/schemas/DocSearchResultsResponse" },
     { "$ref": "#/components/schemas/DocPageResponse" },

@@ -6959,58 +6958,6 @@
       "enum": ["run", "byte", "second"],
       "title": "BlockCostType"
     },
-    "BlockDetails": {
-      "properties": {
-        "id": { "type": "string", "title": "Id" },
-        "name": { "type": "string", "title": "Name" },
-        "description": { "type": "string", "title": "Description" },
-        "inputs": {
-          "additionalProperties": true,
-          "type": "object",
-          "title": "Inputs",
-          "default": {}
-        },
-        "outputs": {
-          "additionalProperties": true,
-          "type": "object",
-          "title": "Outputs",
-          "default": {}
-        },
-        "credentials": {
-          "items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
-          "type": "array",
-          "title": "Credentials",
-          "default": []
-        }
-      },
-      "type": "object",
-      "required": ["id", "name", "description"],
-      "title": "BlockDetails",
-      "description": "Detailed block information."
-    },
-    "BlockDetailsResponse": {
-      "properties": {
-        "type": {
-          "$ref": "#/components/schemas/ResponseType",
-          "default": "block_details"
-        },
-        "message": { "type": "string", "title": "Message" },
-        "session_id": {
-          "anyOf": [{ "type": "string" }, { "type": "null" }],
-          "title": "Session Id"
-        },
-        "block": { "$ref": "#/components/schemas/BlockDetails" },
-        "user_authenticated": {
-          "type": "boolean",
-          "title": "User Authenticated",
-          "default": false
-        }
-      },
-      "type": "object",
-      "required": ["message", "block"],
-      "title": "BlockDetailsResponse",
-      "description": "Response for block details (first run_block attempt)."
-    },
     "BlockInfo": {
       "properties": {
         "id": { "type": "string", "title": "Id" },

@@ -7075,24 +7022,29 @@
         "input_schema": {
           "additionalProperties": true,
           "type": "object",
-          "title": "Input Schema",
-          "description": "Full JSON schema for block inputs"
+          "title": "Input Schema"
         },
         "output_schema": {
           "additionalProperties": true,
           "type": "object",
-          "title": "Output Schema",
-          "description": "Full JSON schema for block outputs"
+          "title": "Output Schema"
         },
         "required_inputs": {
           "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
           "type": "array",
           "title": "Required Inputs",
-          "description": "List of input fields for this block"
+          "description": "List of required input fields for this block"
         }
       },
       "type": "object",
-      "required": ["id", "name", "description", "categories"],
+      "required": [
+        "id",
+        "name",
+        "description",
+        "categories",
+        "input_schema",
+        "output_schema"
+      ],
       "title": "BlockInfoSummary",
       "description": "Summary of a block for search results."
     },

@@ -7138,7 +7090,7 @@
       "usage_hint": {
         "type": "string",
         "title": "Usage Hint",
-        "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs."
+        "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema."
       }
     },
     "type": "object",

@@ -10532,7 +10484,6 @@
       "agent_saved",
       "clarification_needed",
       "block_list",
-      "block_details",
       "block_output",
       "doc_search_results",
       "doc_page",

@@ -10544,10 +10495,7 @@
       "operation_started",
       "operation_pending",
      "operation_in_progress",
-      "input_validation_error",
-      "web_fetch",
-      "bash_exec",
-      "operation_status"
+      "input_validation_error"
     ],
     "title": "ResponseType",
     "description": "Types of tool responses."
@@ -180,14 +180,3 @@ body[data-google-picker-open="true"] [data-dialog-content] {
   z-index: 1 !important;
   pointer-events: none !important;
 }
-
-/* CoPilot chat table styling — remove left/right borders, increase padding */
-[data-streamdown="table-wrapper"] table {
-  border-left: none;
-  border-right: none;
-}
-
-[data-streamdown="table-wrapper"] th,
-[data-streamdown="table-wrapper"] td {
-  padding: 0.875rem 1rem; /* py-3.5 px-4 */
-}
@@ -4,7 +4,9 @@ import { loadScript } from "@/services/scripts/scripts";
 export async function loadGoogleAPIPicker(): Promise<void> {
   validateWindow();

-  await loadScript("https://apis.google.com/js/api.js");
+  await loadScript("https://apis.google.com/js/api.js", {
+    referrerPolicy: "no-referrer-when-downgrade",
+  });

   const googleAPI = window.gapi;
   if (!googleAPI) {

@@ -27,7 +29,9 @@ export async function loadGoogleIdentityServices(): Promise<void> {
     throw new Error("Google Identity Services cannot load on server");
   }

-  await loadScript("https://accounts.google.com/gsi/client");
+  await loadScript("https://accounts.google.com/gsi/client", {
+    referrerPolicy: "no-referrer-when-downgrade",
+  });

   const google = window.google;
   if (!google?.accounts?.oauth2) {
@@ -226,7 +226,7 @@ function renderMarkdown(
   table: ({ children, ...props }) => (
     <div className="my-4 overflow-x-auto">
       <table
-        className="min-w-full divide-y divide-gray-200 border-y border-gray-200 dark:divide-gray-700 dark:border-gray-700"
+        className="min-w-full divide-y divide-gray-200 rounded-lg border border-gray-200 dark:divide-gray-700 dark:border-gray-700"
         {...props}
       >
         {children}

@@ -235,7 +235,7 @@ function renderMarkdown(
   ),
   th: ({ children, ...props }) => (
     <th
-      className="bg-gray-50 px-4 py-3.5 text-left text-xs font-semibold uppercase tracking-wider text-gray-700 dark:bg-gray-800 dark:text-gray-300"
+      className="bg-gray-50 px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-700 dark:bg-gray-800 dark:text-gray-300"
       {...props}
     >
       {children}

@@ -243,7 +243,7 @@ function renderMarkdown(
   ),
   td: ({ children, ...props }) => (
     <td
-      className="border-t border-gray-200 px-4 py-3.5 text-sm text-gray-600 dark:border-gray-700 dark:text-gray-400"
+      className="border-t border-gray-200 px-4 py-3 text-sm text-gray-600 dark:border-gray-700 dark:text-gray-400"
      {...props}
    >
      {children}