mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
23 Commits
feat/track
...
feat/githu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88eaab2baa | ||
|
|
4b0a445635 | ||
|
|
36312d2c6e | ||
|
|
d6d3b8d710 | ||
|
|
17d8d0bf05 | ||
|
|
5a2ab65f41 | ||
|
|
81a318de3e | ||
|
|
62c8e8634b | ||
|
|
b91c959cd9 | ||
|
|
5b95a2a1ef | ||
|
|
9c2a601167 | ||
|
|
b98e37bf23 | ||
|
|
fec8924361 | ||
|
|
712aee7302 | ||
|
|
bef292033e | ||
|
|
ec6974e3b8 | ||
|
|
2ef5e2fe77 | ||
|
|
0a8c7221ce | ||
|
|
840d1de636 | ||
|
|
ac55ab619b | ||
|
|
a8014d1e92 | ||
|
|
7de13c7713 | ||
|
|
9358b525a0 |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,14 +5,12 @@ on:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
merge_group:
|
||||
|
||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,6 +120,175 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
314
.github/workflows/platform-fullstack-ci.yml
vendored
314
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,18 +1,14 @@
|
||||
name: AutoGPT Platform - Full-stack CI
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
@@ -28,28 +24,42 @@ defaults:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Node
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies to populate cache
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
check-api-types:
|
||||
name: check API types
|
||||
runs-on: ubuntu-latest
|
||||
types:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -57,256 +67,70 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
# ------------------------ Backend setup ------------------------
|
||||
|
||||
- name: Set up Backend - Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Backend - Install Poetry
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
|
||||
- name: Set up Backend - Set up dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Set up Backend - Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Set up Backend - Generate Prisma client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
|
||||
# ------------------------ Frontend setup ------------------------
|
||||
|
||||
- name: Set up Frontend - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Frontend - Set up Node
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up Frontend - Install dependencies
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up Frontend - Format OpenAPI schema
|
||||
id: format-schema
|
||||
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after exporting the API schema."
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "\nIn the backend directory:"
|
||||
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||
echo "\nIn the frontend directory:"
|
||||
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||
echo "3. Run 'pnpm generate:api'"
|
||||
echo "4. Run 'pnpm types'"
|
||||
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||
echo "6. Commit and push your changes"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Set up Frontend - Generate API client
|
||||
id: generate-api-client
|
||||
run: pnpm orval --config ./orval.config.ts
|
||||
# Continue with type generation & check even if there are schema changes
|
||||
if: success() || (steps.format-schema.outcome == 'success')
|
||||
|
||||
- name: Check for TypeScript errors
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -27,12 +27,6 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -126,8 +120,6 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -215,7 +207,7 @@ async def list_sessions(
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; defaulting to empty"
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
@@ -237,7 +229,7 @@ async def list_sessions(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -356,7 +348,7 @@ async def update_session_title_route(
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
@@ -397,10 +389,6 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -408,25 +396,6 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/usage",
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -436,7 +405,7 @@ async def get_copilot_usage(
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
@@ -481,7 +450,7 @@ async def cancel_session_task(
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
@@ -498,7 +467,7 @@ async def stream_chat_post(
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Authenticated user ID.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
@@ -507,7 +476,9 @@ async def stream_chat_post(
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
@@ -525,18 +496,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
if user_id:
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -771,7 +730,7 @@ async def stream_chat_post(
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Resume an active stream for a session.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
@@ -252,156 +251,6 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
# Ensure the rate-limit branch is entered by setting a non-zero limit.
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
resets_at = datetime.now(UTC) + timedelta(days=3)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("weekly", resets_at),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"].lower()
|
||||
assert "weekly" in detail
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded(
|
||||
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"]
|
||||
assert "2h" in detail
|
||||
assert "Resets in" in detail
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_rejects_unauthenticated_request() -> None:
|
||||
"""GET /usage should return 401 when no valid JWT is provided."""
|
||||
unauthenticated_app = fastapi.FastAPI()
|
||||
unauthenticated_app.include_router(chat_routes.router)
|
||||
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
|
||||
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -36,15 +36,13 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_get_openai_client,
|
||||
client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -91,7 +89,7 @@ async def _compress_session_messages(
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=_get_openai_client(),
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
@@ -223,10 +221,6 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -238,31 +232,16 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
@@ -415,7 +394,6 @@ async def stream_chat_completion_baseline(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
@@ -433,49 +411,6 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
|
||||
# turns consume prompt tokens on each API call, so the total is underestimated.
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
|
||||
# cannot break out cache_read/cache_creation weights. Users on the
|
||||
# baseline path may be slightly over-counted vs the SDK path.
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -486,16 +421,4 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,27 +70,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Per-turn token cost varies with context size: ~10-15K for early turns,
|
||||
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
|
||||
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
|
||||
# allows ~70-100 turns/day.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# TODO: These are deploy-time constants applied identically to every user.
|
||||
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
|
||||
# must move to the database (e.g., a UserPlan table) and get_usage_status /
|
||||
# check_rate_limit would look up each user's specific limits instead of
|
||||
# reading config.daily_token_limit / config.weekly_token_limit.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
|
||||
162
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
162
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -73,9 +73,6 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
@@ -101,10 +98,7 @@ class ChatSessionInfo(BaseModel):
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts.
|
||||
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
|
||||
# is lost after persistence — the DB only stores aggregate prompt and
|
||||
# completion totals. This is a known limitation.
|
||||
# Calculate usage from token counts
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
|
||||
@@ -95,6 +95,25 @@ Example — committing an image file to GitHub:
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
@@ -105,6 +124,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -119,6 +139,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -152,12 +173,16 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -175,7 +200,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -190,6 +219,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_USAGE_KEY_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
daily_used = 0
|
||||
weekly_used = 0
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for usage status, returning zeros")
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
|
||||
# round-trip entirely.
|
||||
if daily_token_limit <= 0 and weekly_token_limit <= 0:
|
||||
return
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for rate limit check, allowing request")
|
||||
return
|
||||
|
||||
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# transaction=False: these are independent INCRBY+EXPIRE pairs on
|
||||
# separate keys — no cross-key atomicity needed. Skipping
|
||||
# MULTI/EXEC avoids the overhead. If the connection drops between
|
||||
# INCRBY and EXPIRE the key survives until the next date-based key
|
||||
# rotation (daily/weekly), so the memory-leak risk is negligible.
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_USAGE_KEY_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_USAGE_KEY_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
@@ -1,334 +0,0 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -186,43 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="promptTokens",
|
||||
description="Number of uncached prompt tokens",
|
||||
)
|
||||
completion_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="completionTokens",
|
||||
description="Number of completion tokens",
|
||||
)
|
||||
total_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="totalTokens",
|
||||
description="Total number of tokens (raw, not weighted)",
|
||||
)
|
||||
cache_read_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheReadTokens",
|
||||
description="Prompt tokens served from cache (10% cost)",
|
||||
)
|
||||
cache_creation_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheCreationTokens",
|
||||
description="Prompt tokens written to cache (25% cost)",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True, by_alias=True)}\n\n"
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
|
||||
@@ -55,14 +55,12 @@ from ..response_model import (
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from ..service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
_is_langfuse_configured,
|
||||
)
|
||||
from ..token_tracking import persist_and_record_usage
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
@@ -738,13 +736,6 @@ async def stream_chat_completion_sdk(
|
||||
_otel_ctx: Any = None
|
||||
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
# Token usage accumulators — populated from ResultMessage at end of turn
|
||||
turn_prompt_tokens = 0 # uncached input tokens only
|
||||
turn_completion_tokens = 0
|
||||
turn_cache_read_tokens = 0
|
||||
turn_cache_creation_tokens = 0
|
||||
turn_cost_usd: float | None = None
|
||||
|
||||
try:
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support).
|
||||
# Pre-compute the cwd here so the exact working directory path can be
|
||||
@@ -778,7 +769,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
sandbox = await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -792,7 +783,9 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
@@ -1121,7 +1114,7 @@ async def stream_chat_completion_sdk(
|
||||
- len(adapter.resolved_tool_calls),
|
||||
)
|
||||
|
||||
# Log ResultMessage details and capture token usage
|
||||
# Log ResultMessage details for debugging
|
||||
if isinstance(sdk_msg, ResultMessage):
|
||||
logger.info(
|
||||
"%s Received: ResultMessage %s "
|
||||
@@ -1140,33 +1133,6 @@ async def stream_chat_completion_sdk(
|
||||
sdk_msg.result or "(no error message provided)",
|
||||
)
|
||||
|
||||
# Capture token usage from ResultMessage.
|
||||
# Anthropic reports cached tokens separately:
|
||||
# input_tokens = uncached only
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||
turn_cache_read_tokens += sdk_msg.usage.get(
|
||||
"cache_read_input_tokens", 0
|
||||
)
|
||||
turn_cache_creation_tokens += sdk_msg.usage.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
)
|
||||
turn_completion_tokens += sdk_msg.usage.get(
|
||||
"output_tokens", 0
|
||||
)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
|
||||
log_prefix,
|
||||
turn_prompt_tokens,
|
||||
turn_cache_read_tokens,
|
||||
turn_cache_creation_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
if sdk_msg.total_cost_usd is not None:
|
||||
turn_cost_usd = sdk_msg.total_cost_usd
|
||||
|
||||
# Emit compaction end if SDK finished compacting.
|
||||
# When compaction ends, sync TranscriptBuilder with the
|
||||
# CLI's active context so they stay identical.
|
||||
@@ -1383,26 +1349,6 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# Emit token usage to the client (must be in try to reach SSE stream).
|
||||
# Session persistence of usage is in finally to stay consistent with
|
||||
# rate-limit recording even if an exception interrupts between here
|
||||
# and the finally block.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
# total_tokens = prompt (uncached input) + completion (output).
|
||||
# Cache tokens are tracked separately and excluded from total
|
||||
# so that the semantics match the baseline path (OpenRouter)
|
||||
# which folds cache into prompt_tokens. Keeping total_tokens
|
||||
# = prompt + completion everywhere makes cross-path comparisons
|
||||
# and session-level aggregation consistent.
|
||||
total_tokens = turn_prompt_tokens + turn_completion_tokens
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
)
|
||||
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
# to avoid double-uploads (the success path used to upload the
|
||||
# old resume file, then the finally block overwrote it with the
|
||||
@@ -1467,20 +1413,6 @@ async def stream_chat_completion_sdk(
|
||||
except Exception:
|
||||
logger.warning("OTEL context teardown failed", exc_info=True)
|
||||
|
||||
# --- Persist token usage to session + rate-limit counters ---
|
||||
# Both must live in finally so they stay consistent even when an
|
||||
# exception interrupts the try block after StreamUsage was yielded.
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
cache_read_tokens=turn_cache_read_tokens,
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
log_prefix=log_prefix,
|
||||
cost_usd=turn_cost_usd,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
# This MUST run in finally to persist messages even when the generator
|
||||
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
|
||||
|
||||
@@ -28,24 +28,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
|
||||
_client: LangfuseAsyncOpenAI | None = None
|
||||
_langfuse = None
|
||||
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
def _get_openai_client() -> LangfuseAsyncOpenAI:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
return _client
|
||||
|
||||
|
||||
def _get_langfuse():
|
||||
global _langfuse
|
||||
if _langfuse is None:
|
||||
_langfuse = get_client()
|
||||
return _langfuse
|
||||
|
||||
langfuse = get_client()
|
||||
|
||||
# Default system prompt used when Langfuse is not configured
|
||||
# Provides minimal baseline tone and personality - all workflow, tools, and
|
||||
@@ -98,7 +84,7 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
else "latest"
|
||||
)
|
||||
prompt = await asyncio.to_thread(
|
||||
_get_langfuse().get_prompt,
|
||||
langfuse.get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||
@@ -172,7 +158,7 @@ async def _generate_session_title(
|
||||
"environment": settings.config.app_env.value,
|
||||
}
|
||||
|
||||
response = await _get_openai_client().chat.completions.create(
|
||||
response = await client.chat.completions.create(
|
||||
model=config.title_model,
|
||||
messages=[
|
||||
{
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
"""Shared token-usage persistence and rate-limit recording.
|
||||
|
||||
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
|
||||
1. Append a ``Usage`` record to the session.
|
||||
2. Log the turn's token counts.
|
||||
3. Record weighted usage in Redis for rate-limiting.
|
||||
|
||||
This module extracts that common logic so both paths stay in sync.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .rate_limit import record_token_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def persist_and_record_usage(
|
||||
*,
|
||||
session: ChatSession | None,
|
||||
user_id: str | None,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
log_prefix: str = "",
|
||||
cost_usd: float | str | None = None,
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
|
||||
Args:
|
||||
session: The chat session to append usage to (may be None on error).
|
||||
user_id: User ID for rate-limit counters (skipped if None).
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
|
||||
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
if prompt_tokens <= 0 and completion_tokens <= 0:
|
||||
return 0
|
||||
|
||||
# total_tokens = prompt + completion. Cache tokens are tracked
|
||||
# separately and excluded from total so both baseline and SDK
|
||||
# paths share the same semantics.
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
if session is not None:
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
if cache_read_tokens or cache_creation_tokens:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
|
||||
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
|
||||
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
|
||||
f"completion={completion_tokens}, total={total_tokens}"
|
||||
)
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
|
||||
|
||||
return total_tokens
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Unit tests for token_tracking.persist_and_record_usage.
|
||||
|
||||
Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
|
||||
calling conventions, session persistence, and rate-limit recording.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .token_tracking import persist_and_record_usage
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
|
||||
"""Return a minimal in-memory ChatSession for testing."""
|
||||
return ChatSession(
|
||||
session_id="sess-test",
|
||||
user_id="user-test",
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Return value / total_tokens semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTotalTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_prompt_plus_completion(self):
|
||||
"""total_tokens = prompt + completion (cache excluded from total)."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=300,
|
||||
completion_tokens=200,
|
||||
)
|
||||
assert total == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zero_when_no_tokens(self):
|
||||
"""Returns 0 early when both prompt and completion are zero."""
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_excluded_from_total(self):
|
||||
"""Cache tokens are stored separately and not added to total_tokens."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=5000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
# total = prompt + completion only (5000 + 200 cache excluded)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_path_no_cache(self):
|
||||
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="u1",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=400,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
assert total == 1400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sdk_path_with_cache(self):
|
||||
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="u2",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=8000,
|
||||
cache_creation_tokens=400,
|
||||
log_prefix="[SDK]",
|
||||
cost_usd=0.0015,
|
||||
)
|
||||
assert total == 300
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
@pytest.mark.asyncio
|
||||
async def test_appends_usage_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert len(session.usage) == 1
|
||||
usage: Usage = session.usage[0]
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
assert usage.cache_read_tokens == 0
|
||||
assert usage.cache_creation_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_appends_cache_breakdown_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=200,
|
||||
completion_tokens=80,
|
||||
cache_read_tokens=3000,
|
||||
cache_creation_tokens=500,
|
||||
)
|
||||
usage: Usage = session.usage[0]
|
||||
assert usage.cache_read_tokens == 3000
|
||||
assert usage.cache_creation_tokens == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_turns_append_multiple_records(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session, user_id=None, prompt_tokens=100, completion_tokens=50
|
||||
)
|
||||
await persist_and_record_usage(
|
||||
session=session, user_id=None, prompt_tokens=200, completion_tokens=70
|
||||
)
|
||||
assert len(session.usage) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_session_does_not_raise(self):
|
||||
"""When session is None (e.g. error path), no exception should be raised."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_append_when_zero_tokens(self):
|
||||
"""When tokens are zero, function returns early — session unchanged."""
|
||||
session = _make_session()
|
||||
total = await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
assert total == 0
|
||||
assert len(session.usage) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate-limit recording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitRecording:
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_record_token_usage_when_user_id_present(self):
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
mock_record.assert_awaited_once_with(
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_user_id_is_none(self):
|
||||
"""Anonymous sessions should not create Redis keys."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_failure_does_not_raise(self):
|
||||
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
|
||||
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
# Should not raise
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-xyz",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_zero_tokens(self):
|
||||
"""Returns 0 before calling record_token_usage when tokens are zero."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-abc",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -96,7 +97,9 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -133,14 +136,27 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
envs=envs,
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -8,13 +8,11 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.db_accessors import credit_db, workspace_db
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.exceptions import BlockError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
@@ -117,21 +115,6 @@ async def execute_block(
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Pre-execution credit check (courtesy; spend_credits is atomic)
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
_credit_db = credit_db()
|
||||
if has_cost:
|
||||
balance = await _credit_db.get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
@@ -140,51 +123,6 @@ async def execute_block(
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _credit_db.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Block already executed (with possible side effects). Never
|
||||
# return ErrorResponse here — the user received output and
|
||||
# deserves it. Log the billing failure for reconciliation.
|
||||
leak_type = (
|
||||
"INSUFFICIENT_BALANCE"
|
||||
if isinstance(e, InsufficientBalanceError)
|
||||
else "UNEXPECTED_ERROR"
|
||||
)
|
||||
logger.error(
|
||||
"BILLING_LEAK[%s]: block executed but credit charge failed — "
|
||||
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
|
||||
leak_type,
|
||||
user_id,
|
||||
block_id,
|
||||
node_exec_id,
|
||||
cost,
|
||||
e,
|
||||
extra={
|
||||
"json_fields": {
|
||||
"billing_leak": True,
|
||||
"leak_type": leak_type,
|
||||
"user_id": user_id,
|
||||
"cost": str(cost),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
@@ -195,14 +133,14 @@ async def execute_block(
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
|
||||
@@ -1,197 +1,18 @@
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
"""Tests for execute_block type coercion in helpers.py.
|
||||
|
||||
Verifies that execute_block() coerces string input values to match the block's
|
||||
expected input types, mirroring the executor's validate_exec() logic.
|
||||
This is critical for @@agptfile: expansion, where file content is always a string
|
||||
but the block may expect structured types (e.g. list[list[str]]).
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
|
||||
async def _execute(
|
||||
input_data: dict, **kwargs: Any
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
yield "result", "ok"
|
||||
|
||||
mock.execute = _execute
|
||||
return mock
|
||||
|
||||
|
||||
def _patch_workspace():
|
||||
"""Patch workspace_db to return a mock workspace."""
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_ws_db = MagicMock()
|
||||
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
def _patch_credit_db(
|
||||
get_credits_return: int = 100,
|
||||
spend_credits_side_effect: Any = None,
|
||||
):
|
||||
"""Patch credit_db accessor to return a mock credit adapter."""
|
||||
mock_credit = MagicMock()
|
||||
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
|
||||
if spend_credits_side_effect is not None:
|
||||
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
|
||||
else:
|
||||
mock_credit.spend_credits = AsyncMock()
|
||||
return (
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.credit_db",
|
||||
return_value=mock_credit,
|
||||
),
|
||||
mock_credit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_credit.spend_credits.assert_awaited_once()
|
||||
call_kwargs = mock_credit.spend_credits.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_credit.get_credits.assert_not_awaited()
|
||||
mock_credit.spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_output_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, output is still returned (block already ran)."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(
|
||||
get_credits_return=15,
|
||||
spend_credits_side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
# Block already executed (with side effects), so output is returned
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
from backend.copilot.tools.models import BlockOutputResponse
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
@@ -207,7 +28,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
return schema
|
||||
|
||||
|
||||
def _make_coerce_block(
|
||||
def _make_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
@@ -239,7 +60,7 @@ _TEST_USER_ID = "test-user-coerce"
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
@@ -282,7 +103,7 @@ async def test_coerce_json_string_to_nested_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
@@ -314,7 +135,7 @@ async def test_coerce_json_string_to_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
@@ -346,7 +167,7 @@ async def test_coerce_json_string_to_dict():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
@@ -380,7 +201,7 @@ async def test_no_coercion_when_type_matches():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
@@ -413,7 +234,7 @@ async def test_coerce_string_to_int():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
@@ -446,7 +267,7 @@ async def test_coerce_skips_none_values():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
@@ -480,7 +301,7 @@ async def test_coerce_union_type_preserves_valid_member():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_coerce_block(
|
||||
block = _make_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
|
||||
@@ -129,16 +129,3 @@ def review_db():
|
||||
review_db = get_database_manager_async_client()
|
||||
|
||||
return review_db
|
||||
|
||||
|
||||
def credit_db():
|
||||
if db.is_connected():
|
||||
from backend.data import db_manager as _credit_db
|
||||
|
||||
credit_db = _credit_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
credit_db = get_database_manager_async_client()
|
||||
|
||||
return credit_db
|
||||
|
||||
@@ -148,11 +148,6 @@ async def _get_credits(user_id: str) -> int:
|
||||
return await user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
# Public aliases used by db_accessors.credit_db() when Prisma is connected
|
||||
get_credits = _get_credits
|
||||
spend_credits = _spend_credits
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
"""Database connection pooling service.
|
||||
|
||||
@@ -517,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Credits ============ #
|
||||
spend_credits = d.spend_credits
|
||||
get_credits = d.get_credits
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
@@ -25,6 +25,35 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -69,7 +98,11 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -156,6 +189,8 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -168,10 +203,17 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { UploadSimple } from "@phosphor-icons/react";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
@@ -83,9 +89,10 @@ export function CopilotPage() {
|
||||
handleDrawerOpenChange,
|
||||
handleSelectSession,
|
||||
handleNewChat,
|
||||
// Delete functionality (available via ChatSidebar context menu on all viewports)
|
||||
// Delete functionality
|
||||
sessionToDelete,
|
||||
isDeleting,
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
} = useCopilotPage();
|
||||
@@ -141,6 +148,38 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1.5 hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
const session = sessions.find(
|
||||
(s) => s.id === sessionId,
|
||||
);
|
||||
if (session) {
|
||||
handleDeleteClick(session.id, session.title);
|
||||
}
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { LayoutGroup, motion } from "framer-motion";
|
||||
import { ReactNode } from "react";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { EmptySession } from "../EmptySession/EmptySession";
|
||||
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -38,6 +40,7 @@ export const ChatContainer = ({
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
@@ -60,6 +63,7 @@ export const ChatContainer = ({
|
||||
status={status}
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
|
||||
@@ -30,6 +30,7 @@ interface Props {
|
||||
status: string;
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
|
||||
status,
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
{headerSlot}
|
||||
{isLoading && messages.length === 0 && (
|
||||
<div
|
||||
className="flex flex-1 items-center justify-center"
|
||||
|
||||
@@ -3,6 +3,7 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -129,6 +130,8 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -37,7 +37,6 @@ import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
@@ -257,10 +256,11 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="flex items-center">
|
||||
<UsageLimits />
|
||||
<div className="relative left-5 flex items-center gap-1">
|
||||
<NotificationToggle />
|
||||
<SidebarTrigger />
|
||||
<div className="relative left-1">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{sessionId ? (
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
@@ -49,7 +48,10 @@ export function NotificationToggle() {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||
<button
|
||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
||||
aria-label="Notification settings"
|
||||
>
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
@@ -57,7 +59,7 @@ export function NotificationToggle() {
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</Button>
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChartBar } from "@phosphor-icons/react";
|
||||
import { UsagePanelContent } from "./UsagePanelContent";
|
||||
|
||||
export { UsagePanelContent, formatResetTime } from "./UsagePanelContent";
|
||||
|
||||
export function UsageLimits() {
|
||||
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||
<ChartBar className="!size-5" weight="light" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-64 p-3">
|
||||
<UsagePanelContent usage={usage} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import Link from "next/link";
|
||||
|
||||
export function formatResetTime(
|
||||
resetsAt: Date | string,
|
||||
now: Date = new Date(),
|
||||
): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / (1000 * 60 * 60));
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showBillingLink = true,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
showBillingLink?: boolean;
|
||||
}) {
|
||||
const hasDailyLimit = usage.daily.limit > 0;
|
||||
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||
|
||||
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||
return (
|
||||
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||
{hasDailyLimit && (
|
||||
<UsageBar
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{hasWeeklyLimit && (
|
||||
<UsageBar
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
{showBillingLink && (
|
||||
<Link
|
||||
href="/profile/credits"
|
||||
className="text-[11px] text-blue-600 hover:underline"
|
||||
>
|
||||
Learn more about usage limits
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsageLimits } from "../UsageLimits";
|
||||
|
||||
// Mock the generated Orval hook
|
||||
const mockUseGetV2GetCopilotUsage = vi.fn();
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts),
|
||||
}));
|
||||
|
||||
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseGetV2GetCopilotUsage.mockReset();
|
||||
});
|
||||
|
||||
function makeUsage({
|
||||
dailyUsed = 500,
|
||||
dailyLimit = 10000,
|
||||
weeklyUsed = 2000,
|
||||
weeklyLimit = 50000,
|
||||
}: {
|
||||
dailyUsed?: number;
|
||||
dailyLimit?: number;
|
||||
weeklyUsed?: number;
|
||||
weeklyLimit?: number;
|
||||
} = {}) {
|
||||
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||
return {
|
||||
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||
};
|
||||
}
|
||||
|
||||
describe("UsageLimits", () => {
|
||||
it("renders nothing while loading", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: true,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders nothing when no limits are configured", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||
isLoading: false,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders the usage button when limits exist", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays daily and weekly usage percentages", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("50% used")).toBeDefined();
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows only weekly bar when daily limit is 0", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({
|
||||
dailyLimit: 0,
|
||||
weeklyUsed: 25000,
|
||||
weeklyLimit: 50000,
|
||||
}),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
});
|
||||
|
||||
it("caps percentage at 100% when over limit", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows learn more link to credits page", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
const link = screen.getByText("Learn more about usage limits");
|
||||
expect(link).toBeDefined();
|
||||
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,104 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,12 +23,16 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -68,13 +72,17 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
getGetV2GetCopilotUsageQueryKey,
|
||||
getGetV2GetSessionQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
@@ -178,41 +177,12 @@ export function useCopilotStream({
|
||||
onError: (error) => {
|
||||
if (!sessionId) return;
|
||||
|
||||
// Detect rate limit (429) responses and show reset time to the user.
|
||||
// The SDK throws a plain Error whose message is the raw response body
|
||||
// (FastAPI returns {"detail": "...usage limit..."} for 429s).
|
||||
let errorDetail: string = error.message;
|
||||
try {
|
||||
const parsed = JSON.parse(error.message) as unknown;
|
||||
if (
|
||||
typeof parsed === "object" &&
|
||||
parsed !== null &&
|
||||
"detail" in parsed &&
|
||||
typeof (parsed as { detail: unknown }).detail === "string"
|
||||
) {
|
||||
errorDetail = (parsed as { detail: string }).detail;
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — use message as-is
|
||||
}
|
||||
const isRateLimited = errorDetail.toLowerCase().includes("usage limit");
|
||||
if (isRateLimited) {
|
||||
toast({
|
||||
title: "Usage limit reached",
|
||||
description:
|
||||
errorDetail ||
|
||||
"You've reached your usage limit. Please try again later.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Detect authentication failures (from getAuthHeaders or 401 responses)
|
||||
const isAuthError =
|
||||
errorDetail.includes("Authentication failed") ||
|
||||
errorDetail.includes("Unauthorized") ||
|
||||
errorDetail.includes("Not authenticated") ||
|
||||
errorDetail.toLowerCase().includes("401");
|
||||
error.message.includes("Authentication failed") ||
|
||||
error.message.includes("Unauthorized") ||
|
||||
error.message.includes("Not authenticated") ||
|
||||
error.message.toLowerCase().includes("401");
|
||||
if (isAuthError) {
|
||||
toast({
|
||||
title: "Authentication error",
|
||||
@@ -337,9 +307,6 @@ export function useCopilotStream({
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetCopilotUsageQueryKey(),
|
||||
});
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
|
||||
@@ -11,9 +11,6 @@ import {
|
||||
|
||||
import { RefundModal } from "./RefundModal";
|
||||
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
||||
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
import {
|
||||
Table,
|
||||
@@ -24,32 +21,6 @@ import {
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
function CoPilotUsageSection() {
|
||||
const router = useRouter();
|
||||
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<div className="my-6 space-y-4">
|
||||
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
|
||||
<div className="rounded-lg border border-neutral-200 p-4">
|
||||
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||
</div>
|
||||
<Button className="w-full" onClick={() => router.push("/copilot")}>
|
||||
Open CoPilot
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CreditsPage() {
|
||||
const api = useBackendAPI();
|
||||
const {
|
||||
@@ -266,13 +237,11 @@ export default function CreditsPage() {
|
||||
</Button>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* CoPilot Usage Limits */}
|
||||
<CoPilotUsageSection />
|
||||
</div>
|
||||
|
||||
<div className="my-6 space-y-4">
|
||||
{/* Payment Portal */}
|
||||
|
||||
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
||||
<p className="text-neutral-600">
|
||||
You can manage your cards and see your payment history in the
|
||||
|
||||
@@ -1267,7 +1267,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1382,28 +1382,6 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/usage": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Copilot Usage",
|
||||
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
|
||||
"operationId": "getV2GetCopilotUsage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -8477,16 +8455,6 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"CoPilotUsageStatus": {
|
||||
"properties": {
|
||||
"daily": { "$ref": "#/components/schemas/UsageWindow" },
|
||||
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["daily", "weekly"],
|
||||
"title": "CoPilotUsageStatus",
|
||||
"description": "Current usage status for a user across all windows."
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -12222,16 +12190,6 @@
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
},
|
||||
"total_prompt_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Prompt Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"total_completion_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Completion Tokens",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -14629,25 +14587,6 @@
|
||||
"required": ["timezone"],
|
||||
"title": "UpdateTimezoneRequest"
|
||||
},
|
||||
"UsageWindow": {
|
||||
"properties": {
|
||||
"used": { "type": "integer", "title": "Used" },
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"title": "Limit",
|
||||
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||
},
|
||||
"resets_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Resets At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["used", "limit", "resets_at"],
|
||||
"title": "UsageWindow",
|
||||
"description": "Usage within a single time window."
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
@@ -288,7 +288,6 @@ const SidebarTrigger = React.forwardRef<
|
||||
ref={ref}
|
||||
data-sidebar="trigger"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={(event) => {
|
||||
onClick?.(event);
|
||||
toggleSidebar();
|
||||
|
||||
Reference in New Issue
Block a user