mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
77 Commits
fix/copilo
...
feat/track
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a50cf17f4 | ||
|
|
f0e595a447 | ||
|
|
291eaf3341 | ||
|
|
a8fe383481 | ||
|
|
aff3fb44af | ||
|
|
c342290910 | ||
|
|
7ff8bc8c5e | ||
|
|
b547c734cb | ||
|
|
0a6d708411 | ||
|
|
b031cc55ec | ||
|
|
ca5a0e619e | ||
|
|
fd75e14eb8 | ||
|
|
1476a580b2 | ||
|
|
a6af72ed51 | ||
|
|
dc3fa69f84 | ||
|
|
8682ef4b71 | ||
|
|
a3b7bf4424 | ||
|
|
80670345af | ||
|
|
3fc10e9fc0 | ||
|
|
e71979e0b2 | ||
|
|
1456659dba | ||
|
|
acb0d41a35 | ||
|
|
4a08482aa6 | ||
|
|
20e410bd5a | ||
|
|
8e6d140580 | ||
|
|
d6cec14543 | ||
|
|
abd9dac288 | ||
|
|
6cc2c7a50d | ||
|
|
a0d534f24b | ||
|
|
b9be577904 | ||
|
|
b9951a3c53 | ||
|
|
6fb3b1e87b | ||
|
|
760d4eeaf4 | ||
|
|
cb7d271472 | ||
|
|
7ef530c672 | ||
|
|
699ecc8cec | ||
|
|
a662dc1680 | ||
|
|
211be3aff1 | ||
|
|
3120981e4b | ||
|
|
73e1a5e76d | ||
|
|
5966d3669d | ||
|
|
d0f9ec55e4 | ||
|
|
c81ab1fc3b | ||
|
|
36606b206a | ||
|
|
a8afaffd4c | ||
|
|
a982fb8436 | ||
|
|
afcce75aff | ||
|
|
70dfe64c6d | ||
|
|
5446c7f18f | ||
|
|
2b0c9ba703 | ||
|
|
195c7011ae | ||
|
|
d4944fb22b | ||
|
|
a5ed8fefa9 | ||
|
|
a52a777b29 | ||
|
|
8bec7a6933 | ||
|
|
e73791efed | ||
|
|
2d161ce2b9 | ||
|
|
6fc4989654 | ||
|
|
976443bf6e | ||
|
|
4ceb15b3f1 | ||
|
|
3096f94996 | ||
|
|
6f90729612 | ||
|
|
ebf89dde8b | ||
|
|
5d057e97e5 | ||
|
|
1d2f641a26 | ||
|
|
dcb71ab0b9 | ||
|
|
8136b90860 | ||
|
|
4d179a7c37 | ||
|
|
f78adcdc65 | ||
|
|
40388b7520 | ||
|
|
dd7be1158b | ||
|
|
c0e59f0a6b | ||
|
|
104d1f1bf4 | ||
|
|
d9e9cd4c98 | ||
|
|
ca416300ec | ||
|
|
c589cd0c43 | ||
|
|
b6d863fcd2 |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -5,12 +5,14 @@ on:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-backend-ci.yml"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/backend/**"
|
||||
- "autogpt_platform/autogpt_libs/**"
|
||||
merge_group:
|
||||
|
||||
169
.github/workflows/platform-frontend-ci.yml
vendored
169
.github/workflows/platform-frontend-ci.yml
vendored
@@ -120,175 +120,6 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
312
.github/workflows/platform-fullstack-ci.yml
vendored
312
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,14 +1,18 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
name: AutoGPT Platform - Full-stack CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
|
||||
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
@@ -24,42 +28,28 @@ defaults:
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies to populate cache
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: big-boi
|
||||
check-api-types:
|
||||
name: check API types
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -67,70 +57,256 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
# ------------------------ Backend setup ------------------------
|
||||
|
||||
- name: Set up Backend - Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Backend - Install Poetry
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Installing Poetry version ${POETRY_VERSION}"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
|
||||
- name: Set up Backend - Set up dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Set up Backend - Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Set up Backend - Generate Prisma client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- name: Set up Frontend - Export OpenAPI schema from Backend
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||
|
||||
# ------------------------ Frontend setup ------------------------
|
||||
|
||||
- name: Set up Frontend - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up Frontend - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Set up Frontend - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
- name: Set up Frontend - Format OpenAPI schema
|
||||
id: format-schema
|
||||
run: pnpm prettier --write ./src/app/api/openapi.json
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "The openapi.json file has been modified after exporting the API schema."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo "\nIn the backend directory:"
|
||||
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
|
||||
echo "\nIn the frontend directory:"
|
||||
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
|
||||
echo "3. Run 'pnpm generate:api'"
|
||||
echo "4. Run 'pnpm types'"
|
||||
echo "5. Fix any TypeScript errors that may have been introduced"
|
||||
echo "6. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
- name: Set up Frontend - Generate API client
|
||||
id: generate-api-client
|
||||
run: pnpm orval --config ./orval.config.ts
|
||||
# Continue with type generation & check even if there are schema changes
|
||||
if: success() || (steps.format-schema.outcome == 'success')
|
||||
|
||||
- name: Check for TypeScript errors
|
||||
run: pnpm types
|
||||
if: success() || (steps.generate-api-client.outcome == 'success')
|
||||
|
||||
e2e_test:
|
||||
name: end-to-end tests
|
||||
runs-on: big-boi
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Platform - Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Platform - Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver: docker-container
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||
uses: crazy-max/ghaction-github-runtime@v4
|
||||
|
||||
- name: Set up Platform - Build Docker images (with cache)
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
pip install pyyaml
|
||||
|
||||
# Resolve extends and generate a flat compose file that bake can understand
|
||||
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||
|
||||
# Add cache configuration to the resolved compose file
|
||||
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||
--source docker-compose.resolved.yml \
|
||||
--cache-from "type=gha" \
|
||||
--cache-to "type=gha,mode=max" \
|
||||
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
|
||||
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
|
||||
--git-ref "${{ github.ref }}"
|
||||
|
||||
# Build with bake using the resolved compose file (now includes cache config)
|
||||
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Cache E2E test data
|
||||
id: e2e-data-cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: /tmp/e2e_test_data.sql
|
||||
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
|
||||
|
||||
- name: Set up Platform - Start Supabase DB + Auth
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for auth service to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||
|
||||
- name: Set up Platform - Run migrations
|
||||
run: |
|
||||
echo "Running migrations..."
|
||||
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||
echo "✅ Migrations completed"
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Load cached E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
echo "✅ Found cached E2E test data, restoring..."
|
||||
{
|
||||
echo "SET session_replication_role = 'replica';"
|
||||
cat /tmp/e2e_test_data.sql
|
||||
echo "SET session_replication_role = 'origin';"
|
||||
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||
# Refresh materialized views after restore
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||
|
||||
echo "✅ E2E test data restored from cache"
|
||||
|
||||
- name: Set up Platform - Start (all other services)
|
||||
run: |
|
||||
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Set up tests - Create E2E test data
|
||||
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||
echo "Dumping database for cache..."
|
||||
{
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--table='auth.users' postgres
|
||||
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||
pg_dump -U postgres --data-only --column-inserts \
|
||||
--schema=platform \
|
||||
--exclude-table='platform._prisma_migrations' \
|
||||
--exclude-table='platform.apscheduler_jobs' \
|
||||
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||
postgres
|
||||
} > /tmp/e2e_test_data.sql
|
||||
|
||||
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||
|
||||
- name: Set up tests - Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set up tests - Set up Node
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||
from fastapi import APIRouter, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from prisma.models import UserWorkspaceFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -27,6 +27,12 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||
from backend.copilot.tools.models import (
|
||||
@@ -120,6 +126,8 @@ class SessionDetailResponse(BaseModel):
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -207,7 +215,7 @@ async def list_sessions(
|
||||
}
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||
"Failed to fetch processing status from Redis; defaulting to empty"
|
||||
)
|
||||
|
||||
return ListSessionsResponse(
|
||||
@@ -229,7 +237,7 @@ async def list_sessions(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -348,7 +356,7 @@ async def update_session_title_route(
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
@@ -389,6 +397,10 @@ async def get_session(
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
# Sum token usage from session
|
||||
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
@@ -396,6 +408,25 @@ async def get_session(
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
total_prompt_tokens=total_prompt,
|
||||
total_completion_tokens=total_completion,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/usage",
|
||||
)
|
||||
async def get_copilot_usage(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get CoPilot usage status for the authenticated user.
|
||||
|
||||
Returns current token usage vs limits for daily and weekly windows.
|
||||
"""
|
||||
return await get_usage_status(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -405,7 +436,7 @@ async def get_session(
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> CancelSessionResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
@@ -450,7 +481,7 @@ async def cancel_session_task(
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
@@ -467,7 +498,7 @@ async def stream_chat_post(
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
user_id: Authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
@@ -476,9 +507,7 @@ async def stream_chat_post(
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
@@ -496,6 +525,18 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Pre-turn rate limit check (token-based).
|
||||
# check_rate_limit short-circuits internally when both limits are 0.
|
||||
if user_id:
|
||||
try:
|
||||
await check_rate_limit(
|
||||
user_id=user_id,
|
||||
daily_token_limit=config.daily_token_limit,
|
||||
weekly_token_limit=config.weekly_token_limit,
|
||||
)
|
||||
except RateLimitExceeded as e:
|
||||
raise HTTPException(status_code=429, detail=str(e)) from e
|
||||
|
||||
# Enrich message with file metadata if file_ids are provided.
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
@@ -730,7 +771,7 @@ async def stream_chat_post(
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
user_id: str = Security(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Resume an active stream for a session.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import fastapi
|
||||
@@ -251,6 +252,156 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
assert call_kwargs["where"]["isDeleted"] is False
|
||||
|
||||
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
# Ensure the rate-limit branch is entered by setting a non-zero limit.
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
resets_at = datetime.now(UTC) + timedelta(days=3)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded("weekly", resets_at),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"].lower()
|
||||
assert "weekly" in detail
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.check_rate_limit",
|
||||
side_effect=RateLimitExceeded(
|
||||
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/sessions/sess-1/stream",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
detail = response.json()["detail"]
|
||||
assert "2h" in detail
|
||||
assert "Resets in" in detail
|
||||
|
||||
|
||||
# ─── Usage endpoint ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_usage(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
daily_used: int = 500,
|
||||
weekly_used: int = 2000,
|
||||
) -> AsyncMock:
|
||||
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
|
||||
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
|
||||
|
||||
resets_at = datetime.now(UTC) + timedelta(days=1)
|
||||
status = CoPilotUsageStatus(
|
||||
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
|
||||
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
|
||||
)
|
||||
return mocker.patch(
|
||||
"backend.api.features.chat.routes.get_usage_status",
|
||||
new_callable=AsyncMock,
|
||||
return_value=status,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_returns_daily_and_weekly(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""GET /usage returns daily and weekly usage."""
|
||||
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["daily"]["used"] == 500
|
||||
assert data["weekly"]["used"] == 2000
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=10000,
|
||||
weekly_token_limit=50000,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_uses_config_limits(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
|
||||
mock_get = _mock_usage(mocker)
|
||||
|
||||
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
|
||||
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
|
||||
|
||||
response = client.get("/usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
daily_token_limit=99999,
|
||||
weekly_token_limit=77777,
|
||||
)
|
||||
|
||||
|
||||
def test_usage_rejects_unauthenticated_request() -> None:
|
||||
"""GET /usage should return 401 when no valid JWT is provided."""
|
||||
unauthenticated_app = fastapi.FastAPI()
|
||||
unauthenticated_app.include_router(chat_routes.router)
|
||||
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
|
||||
|
||||
response = unauthenticated_client.get("/usage")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -36,13 +36,15 @@ from backend.copilot.response_model import (
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from backend.copilot.service import (
|
||||
_build_system_prompt,
|
||||
_generate_session_title,
|
||||
client,
|
||||
_get_openai_client,
|
||||
config,
|
||||
)
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -89,7 +91,7 @@ async def _compress_session_messages(
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
client=_get_openai_client(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
@@ -221,6 +223,10 @@ async def stream_chat_completion_baseline(
|
||||
text_block_id = str(uuid.uuid4())
|
||||
text_started = False
|
||||
step_open = False
|
||||
# Token usage accumulators — populated from streaming chunks
|
||||
turn_prompt_tokens = 0
|
||||
turn_completion_tokens = 0
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
try:
|
||||
for _round in range(_MAX_TOOL_ROUNDS):
|
||||
# Open a new step for each LLM round
|
||||
@@ -232,16 +238,31 @@ async def stream_chat_completion_baseline(
|
||||
model=config.model,
|
||||
messages=openai_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
if tools:
|
||||
create_kwargs["tools"] = tools
|
||||
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
|
||||
|
||||
# Accumulate streamed response (text + tool calls)
|
||||
round_text = ""
|
||||
tool_calls_by_index: dict[int, dict[str, str]] = {}
|
||||
|
||||
async for chunk in response:
|
||||
# Capture token usage from the streaming chunk.
|
||||
# OpenRouter normalises all providers into OpenAI format
|
||||
# where prompt_tokens already includes cached tokens
|
||||
# (unlike Anthropic's native API). Use += to sum all
|
||||
# tool-call rounds since each API call is independent.
|
||||
# NOTE: stream_options={"include_usage": True} is not
|
||||
# universally supported — some providers (Mistral, Llama
|
||||
# via OpenRouter) always return chunk.usage=None. When
|
||||
# that happens, tokens stay 0 and the tiktoken fallback
|
||||
# below activates. Fail-open: one round is estimated.
|
||||
if chunk.usage:
|
||||
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
continue
|
||||
@@ -394,6 +415,7 @@ async def stream_chat_completion_baseline(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Close any open text/step before emitting error
|
||||
@@ -411,6 +433,49 @@ async def stream_chat_completion_baseline(
|
||||
except Exception:
|
||||
logger.warning("[Baseline] Langfuse trace context teardown failed")
|
||||
|
||||
# Fallback: estimate tokens via tiktoken when the provider does
|
||||
# not honour stream_options={"include_usage": True}.
|
||||
# Count the full message list (system + history + turn) since
|
||||
# each API call sends the complete context window.
|
||||
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
|
||||
# turns consume prompt tokens on each API call, so the total is underestimated.
|
||||
# Skip fallback when an error occurred and no output was produced —
|
||||
# charging rate-limit tokens for completely failed requests is unfair.
|
||||
if (
|
||||
turn_prompt_tokens == 0
|
||||
and turn_completion_tokens == 0
|
||||
and not (_stream_error and not assistant_text)
|
||||
):
|
||||
from backend.util.prompt import (
|
||||
estimate_token_count,
|
||||
estimate_token_count_str,
|
||||
)
|
||||
|
||||
turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.model), 1
|
||||
)
|
||||
turn_completion_tokens = estimate_token_count_str(
|
||||
assistant_text, model=config.model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
"prompt=%d, completion=%d",
|
||||
turn_prompt_tokens,
|
||||
turn_completion_tokens,
|
||||
)
|
||||
|
||||
# Persist token usage to session and record for rate limiting.
|
||||
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
|
||||
# cannot break out cache_read/cache_creation weights. Users on the
|
||||
# baseline path may be slightly over-counted vs the SDK path.
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
|
||||
# Persist assistant response
|
||||
if assistant_text:
|
||||
session.messages.append(
|
||||
@@ -421,4 +486,16 @@ async def stream_chat_completion_baseline(
|
||||
except Exception as persist_err:
|
||||
logger.error("[Baseline] Failed to persist session: %s", persist_err)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
# aclose() — doing so raises RuntimeError on client disconnect.
|
||||
# On GeneratorExit the client is already gone, so unreachable yields
|
||||
# are harmless; on normal completion they reach the SSE stream.
|
||||
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
|
||||
yield StreamUsage(
|
||||
prompt_tokens=turn_prompt_tokens,
|
||||
completion_tokens=turn_completion_tokens,
|
||||
total_tokens=turn_prompt_tokens + turn_completion_tokens,
|
||||
)
|
||||
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -70,6 +70,27 @@ class ChatConfig(BaseSettings):
|
||||
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
|
||||
)
|
||||
|
||||
# Rate limiting — token-based limits per day and per week.
|
||||
# Per-turn token cost varies with context size: ~10-15K for early turns,
|
||||
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
|
||||
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
|
||||
# allows ~70-100 turns/day.
|
||||
# Checked at the HTTP layer (routes.py) before each turn.
|
||||
#
|
||||
# TODO: These are deploy-time constants applied identically to every user.
|
||||
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
|
||||
# must move to the database (e.g., a UserPlan table) and get_usage_status /
|
||||
# check_rate_limit would look up each user's specific limits instead of
|
||||
# reading config.daily_token_limit / config.weekly_token_limit.
|
||||
daily_token_limit: int = Field(
|
||||
default=2_500_000,
|
||||
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
|
||||
)
|
||||
weekly_token_limit: int = Field(
|
||||
default=12_500_000,
|
||||
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
@@ -115,7 +136,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -73,6 +73,9 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
|
||||
cache_read_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
|
||||
class ChatSessionInfo(BaseModel):
|
||||
@@ -98,7 +101,10 @@ class ChatSessionInfo(BaseModel):
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
# Calculate usage from token counts.
|
||||
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
|
||||
# is lost after persistence — the DB only stores aggregate prompt and
|
||||
# completion totals. This is a known limitation.
|
||||
usage = []
|
||||
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
|
||||
usage.append(
|
||||
|
||||
266
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
266
autogpt_platform/backend/backend/copilot/rate_limit.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""CoPilot rate limiting based on token usage.
|
||||
|
||||
Uses Redis fixed-window counters to track per-user token consumption
|
||||
with configurable daily and weekly limits. Daily windows reset at
|
||||
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
|
||||
UTC). Fails open when Redis is unavailable to avoid blocking users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key prefixes
|
||||
_USAGE_KEY_PREFIX = "copilot:usage"
|
||||
|
||||
|
||||
class UsageWindow(BaseModel):
|
||||
"""Usage within a single time window."""
|
||||
|
||||
used: int
|
||||
limit: int = Field(
|
||||
description="Maximum tokens allowed in this window. 0 means unlimited."
|
||||
)
|
||||
resets_at: datetime
|
||||
|
||||
|
||||
class CoPilotUsageStatus(BaseModel):
|
||||
"""Current usage status for a user across all windows."""
|
||||
|
||||
daily: UsageWindow
|
||||
weekly: UsageWindow
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds their CoPilot usage limit."""
|
||||
|
||||
def __init__(self, window: str, resets_at: datetime):
|
||||
self.window = window
|
||||
self.resets_at = resets_at
|
||||
delta = resets_at - datetime.now(UTC)
|
||||
total_secs = delta.total_seconds()
|
||||
if total_secs <= 0:
|
||||
time_str = "now"
|
||||
else:
|
||||
hours = int(total_secs // 3600)
|
||||
minutes = int((total_secs % 3600) // 60)
|
||||
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
|
||||
super().__init__(
|
||||
f"You've reached your {window} usage limit. Resets in {time_str}."
|
||||
)
|
||||
|
||||
|
||||
async def get_usage_status(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> CoPilotUsageStatus:
|
||||
"""Get current usage status for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
daily_token_limit: Max tokens per day (0 = unlimited).
|
||||
weekly_token_limit: Max tokens per week (0 = unlimited).
|
||||
|
||||
Returns:
|
||||
CoPilotUsageStatus with current usage and limits.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
daily_used = 0
|
||||
weekly_used = 0
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for usage status, returning zeros")
|
||||
|
||||
return CoPilotUsageStatus(
|
||||
daily=UsageWindow(
|
||||
used=daily_used,
|
||||
limit=daily_token_limit,
|
||||
resets_at=_daily_reset_time(now=now),
|
||||
),
|
||||
weekly=UsageWindow(
|
||||
used=weekly_used,
|
||||
limit=weekly_token_limit,
|
||||
resets_at=_weekly_reset_time(now=now),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def check_rate_limit(
|
||||
user_id: str,
|
||||
daily_token_limit: int,
|
||||
weekly_token_limit: int,
|
||||
) -> None:
|
||||
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
|
||||
|
||||
This is a pre-turn soft check. The authoritative usage counter is updated
|
||||
by ``record_token_usage()`` after the turn completes. Under concurrency,
|
||||
two parallel turns may both pass this check against the same snapshot.
|
||||
This is acceptable because token-based limits are approximate by nature
|
||||
(the exact token count is unknown until after generation).
|
||||
|
||||
Fails open: if Redis is unavailable, allows the request.
|
||||
"""
|
||||
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
|
||||
# round-trip entirely.
|
||||
if daily_token_limit <= 0 and weekly_token_limit <= 0:
|
||||
return
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
daily_raw, weekly_raw = await asyncio.gather(
|
||||
redis.get(_daily_key(user_id, now=now)),
|
||||
redis.get(_weekly_key(user_id, now=now)),
|
||||
)
|
||||
daily_used = int(daily_raw or 0)
|
||||
weekly_used = int(weekly_raw or 0)
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning("Redis unavailable for rate limit check, allowing request")
|
||||
return
|
||||
|
||||
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
|
||||
if daily_token_limit > 0 and daily_used >= daily_token_limit:
|
||||
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
|
||||
|
||||
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
|
||||
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
|
||||
|
||||
|
||||
async def record_token_usage(
|
||||
user_id: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
Uses cost-weighted counting so cached tokens don't unfairly penalise
|
||||
multi-turn conversations. Anthropic's pricing:
|
||||
- uncached input: 100%
|
||||
- cache creation: 25%
|
||||
- cache read: 10%
|
||||
- output: 100%
|
||||
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
weighted_input = (
|
||||
prompt_tokens
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
raw_total = (
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
completion_tokens,
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
# transaction=False: these are independent INCRBY+EXPIRE pairs on
|
||||
# separate keys — no cross-key atomicity needed. Skipping
|
||||
# MULTI/EXEC avoids the overhead. If the connection drops between
|
||||
# INCRBY and EXPIRE the key survives until the next date-based key
|
||||
# rotation (daily/weekly), so the memory-leak risk is negligible.
|
||||
pipe = redis.pipeline(transaction=False)
|
||||
|
||||
# Daily counter (expires at next midnight UTC)
|
||||
d_key = _daily_key(user_id, now=now)
|
||||
pipe.incrby(d_key, total)
|
||||
seconds_until_daily_reset = int(
|
||||
(_daily_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
|
||||
|
||||
# Weekly counter (expires end of week)
|
||||
w_key = _weekly_key(user_id, now=now)
|
||||
pipe.incrby(w_key, total)
|
||||
seconds_until_weekly_reset = int(
|
||||
(_weekly_reset_time(now=now) - now).total_seconds()
|
||||
)
|
||||
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
|
||||
|
||||
await pipe.execute()
|
||||
except (RedisError, ConnectionError, OSError):
|
||||
logger.warning(
|
||||
"Redis unavailable for recording token usage (tokens=%d)",
|
||||
total,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _daily_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return f"{_USAGE_KEY_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
|
||||
|
||||
|
||||
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
year, week, _ = now.isocalendar()
|
||||
return f"{_USAGE_KEY_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
|
||||
|
||||
|
||||
def _daily_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current daily window resets (next midnight UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
|
||||
def _weekly_reset_time(now: datetime | None = None) -> datetime:
|
||||
"""Calculate when the current weekly window resets (next Monday 00:00 UTC)."""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
days_until_monday = (7 - now.weekday()) % 7 or 7
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
|
||||
days=days_until_monday
|
||||
)
|
||||
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
334
autogpt_platform/backend/backend/copilot/rate_limit_test.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Unit tests for CoPilot rate limiting."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from .rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
check_rate_limit,
|
||||
get_usage_status,
|
||||
record_token_usage,
|
||||
)
|
||||
|
||||
_USER = "test-user-rl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitExceeded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
def test_message_contains_window_name(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
|
||||
assert "daily" in str(exc)
|
||||
|
||||
def test_message_contains_reset_time(self):
|
||||
exc = RateLimitExceeded(
|
||||
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
|
||||
)
|
||||
msg = str(exc)
|
||||
# Allow for slight timing drift (29m or 30m)
|
||||
assert "2h " in msg
|
||||
assert "Resets in" in msg
|
||||
|
||||
def test_message_minutes_only_when_under_one_hour(self):
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
|
||||
msg = str(exc)
|
||||
assert "Resets in" in msg
|
||||
# Should not have "0h"
|
||||
assert "0h" not in msg
|
||||
|
||||
def test_message_says_now_when_resets_at_is_in_the_past(self):
|
||||
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
|
||||
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
|
||||
assert "Resets in now" in str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_usage_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUsageStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_redis_values(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert isinstance(status, CoPilotUsageStatus)
|
||||
assert status.daily.used == 500
|
||||
assert status.daily.limit == 10000
|
||||
assert status.weekly.used == 2000
|
||||
assert status.weekly.limit == 50000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zeros_when_redis_unavailable(self):
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_daily_counter(self):
|
||||
"""Daily counter is None (new day), weekly has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 0
|
||||
assert status.weekly.used == 3000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_none_weekly_counter(self):
|
||||
"""Weekly counter is None (start of week), daily has usage."""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["500", None])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
assert status.daily.used == 500
|
||||
assert status.weekly.used == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_at_daily_is_next_midnight_utc(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["0", "0"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
status = await get_usage_status(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
# Daily reset should be within 24h
|
||||
assert status.daily.resets_at > now
|
||||
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_rate_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRateLimit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_under_limit(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_daily_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "daily"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_weekly_limit_exceeded(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
assert exc_info.value.window == "weekly"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_when_redis_unavailable(self):
|
||||
"""Fail-open: allow requests when Redis is down."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await check_rate_limit(
|
||||
_USER, daily_token_limit=10000, weekly_token_limit=50000
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_check_when_limit_is_zero(self):
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — limits of 0 mean unlimited
|
||||
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# record_token_usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
@staticmethod
|
||||
def _make_pipeline_mock() -> MagicMock:
|
||||
"""Create a pipeline mock with sync methods and async execute."""
|
||||
pipe = MagicMock()
|
||||
pipe.execute = AsyncMock(return_value=[])
|
||||
return pipe
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increments_redis_counters(self):
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
# Should call incrby twice (daily + weekly) with total=150
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 150 # daily
|
||||
assert incrby_calls[1].args[1] == 150 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_zero_tokens(self):
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
# Should not call pipeline at all
|
||||
mock_redis.pipeline.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_expire_on_both_keys(self):
|
||||
"""Pipeline should call expire for both daily and weekly keys."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
expire_calls = mock_pipe.expire.call_args_list
|
||||
assert len(expire_calls) == 2
|
||||
|
||||
# Daily key TTL should be positive (seconds until next midnight)
|
||||
daily_ttl = expire_calls[0].args[1]
|
||||
assert daily_ttl >= 1
|
||||
|
||||
# Weekly key TTL should be positive (seconds until next Monday)
|
||||
weekly_ttl = expire_calls[1].args[1]
|
||||
assert weekly_ttl >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_failure_gracefully(self):
|
||||
"""Should not raise when Redis is unavailable."""
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
side_effect=ConnectionError("Redis down"),
|
||||
):
|
||||
# Should not raise
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_weighted_counting(self):
|
||||
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
await record_token_usage(
|
||||
_USER,
|
||||
prompt_tokens=100, # uncached → 100
|
||||
completion_tokens=50, # output → 50
|
||||
cache_read_tokens=10000, # 10% → 1000
|
||||
cache_creation_tokens=400, # 25% → 100
|
||||
)
|
||||
|
||||
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
|
||||
incrby_calls = mock_pipe.incrby.call_args_list
|
||||
assert len(incrby_calls) == 2
|
||||
assert incrby_calls[0].args[1] == 1250 # daily
|
||||
assert incrby_calls[1].args[1] == 1250 # weekly
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_redis_error_during_pipeline_execute(self):
|
||||
"""Should not raise when pipeline.execute() fails with RedisError."""
|
||||
mock_pipe = self._make_pipeline_mock()
|
||||
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pipeline = lambda **_kw: mock_pipe
|
||||
|
||||
with patch(
|
||||
"backend.copilot.rate_limit.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
# Should not raise — fail-open
|
||||
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
|
||||
@@ -43,7 +43,6 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -187,12 +186,43 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
"""Token usage statistics.
|
||||
|
||||
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
|
||||
(it uses z.strictObject() and rejects unknown event types).
|
||||
Usage data is recorded server-side (session DB + Redis counters).
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
prompt_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="promptTokens",
|
||||
description="Number of uncached prompt tokens",
|
||||
)
|
||||
completion_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="completionTokens",
|
||||
description="Number of completion tokens",
|
||||
)
|
||||
total_tokens: int = Field(
|
||||
...,
|
||||
serialization_alias="totalTokens",
|
||||
description="Total number of tokens (raw, not weighted)",
|
||||
)
|
||||
cache_read_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheReadTokens",
|
||||
description="Prompt tokens served from cache (10% cost)",
|
||||
)
|
||||
cache_creation_tokens: int = Field(
|
||||
default=0,
|
||||
serialization_alias="cacheCreationTokens",
|
||||
description="Prompt tokens written to cache (25% cost)",
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Emit as SSE comment so the AI SDK parser ignores it."""
|
||||
return f": usage {self.model_dump_json(exclude_none=True, by_alias=True)}\n\n"
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
@@ -233,26 +263,3 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Encode as an SSE comment so the AI SDK stream parser ignores it.
|
||||
|
||||
The frontend AI SDK validates every ``data:`` line against a strict
|
||||
Zod union of known chunk types. ``"status"`` is not in that union,
|
||||
so sending it as ``data:`` would cause a schema-validation error that
|
||||
breaks the entire stream. Using an SSE comment (``:``) keeps the
|
||||
connection alive and is silently discarded by ``EventSource`` parsers.
|
||||
"""
|
||||
return f": status {self.message}\n\n"
|
||||
|
||||
@@ -12,7 +12,6 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -120,12 +119,14 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -221,7 +222,6 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -1,552 +0,0 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
@@ -226,7 +226,7 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -89,9 +89,7 @@ def _validate_tool_access(
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
"Blocked dangerous pattern in tool input: %s in %s",
|
||||
pattern,
|
||||
tool_name,
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return _deny(
|
||||
"[SECURITY] Input contains a blocked pattern. "
|
||||
@@ -113,9 +111,7 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -174,7 +170,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -185,9 +181,7 @@ def create_security_hooks(
|
||||
)
|
||||
if len(task_tool_use_ids) >= max_subtasks:
|
||||
logger.warning(
|
||||
"[SDK] Task limit reached (%d), user=%s",
|
||||
max_subtasks,
|
||||
user_id,
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
@@ -218,7 +212,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
@@ -288,11 +282,8 @@ def create_security_hooks(
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
|
||||
tool_name,
|
||||
str(error).replace("\n", "").replace("\r", ""),
|
||||
user_id,
|
||||
tool_use_id,
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
@@ -310,19 +301,16 @@ def create_security_hooks(
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
trigger = (
|
||||
str(input_data.get("trigger", "auto"))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
|
||||
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||
"transcript_path=%s",
|
||||
trigger,
|
||||
user_id,
|
||||
transcript_path,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,283 +0,0 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
|
||||
@@ -10,9 +10,6 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -20,12 +17,8 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,14 +99,7 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
while parent in stripped_uuids:
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -341,7 +327,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -351,17 +337,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -422,6 +408,8 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -460,15 +448,17 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -504,14 +494,11 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
)
|
||||
|
||||
|
||||
@@ -525,6 +512,8 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -532,10 +521,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -547,14 +536,10 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, Exception):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -568,6 +553,8 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -584,280 +571,3 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -382,7 +382,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -402,7 +402,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -420,7 +420,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -897,134 +897,3 @@ class TestCompactionFlowIntegration:
|
||||
output2 = builder2.to_jsonl()
|
||||
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
|
||||
assert lines2[-1]["parentUuid"] == "a2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@@ -28,10 +28,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
_client: LangfuseAsyncOpenAI | None = None
|
||||
_langfuse = None
|
||||
|
||||
|
||||
langfuse = get_client()
|
||||
def _get_openai_client() -> LangfuseAsyncOpenAI:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
return _client
|
||||
|
||||
|
||||
def _get_langfuse():
|
||||
global _langfuse
|
||||
if _langfuse is None:
|
||||
_langfuse = get_client()
|
||||
return _langfuse
|
||||
|
||||
|
||||
# Default system prompt used when Langfuse is not configured
|
||||
# Provides minimal baseline tone and personality - all workflow, tools, and
|
||||
@@ -84,7 +98,7 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
else "latest"
|
||||
)
|
||||
prompt = await asyncio.to_thread(
|
||||
langfuse.get_prompt,
|
||||
_get_langfuse().get_prompt,
|
||||
config.langfuse_prompt_name,
|
||||
label=label,
|
||||
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
|
||||
@@ -158,7 +172,7 @@ async def _generate_session_title(
|
||||
"environment": settings.config.app_env.value,
|
||||
}
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
response = await _get_openai_client().chat.completions.create(
|
||||
model=config.title_model,
|
||||
messages=[
|
||||
{
|
||||
|
||||
93
autogpt_platform/backend/backend/copilot/token_tracking.py
Normal file
93
autogpt_platform/backend/backend/copilot/token_tracking.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Shared token-usage persistence and rate-limit recording.
|
||||
|
||||
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
|
||||
1. Append a ``Usage`` record to the session.
|
||||
2. Log the turn's token counts.
|
||||
3. Record weighted usage in Redis for rate-limiting.
|
||||
|
||||
This module extracts that common logic so both paths stay in sync.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .rate_limit import record_token_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def persist_and_record_usage(
|
||||
*,
|
||||
session: ChatSession | None,
|
||||
user_id: str | None,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
log_prefix: str = "",
|
||||
cost_usd: float | str | None = None,
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
|
||||
Args:
|
||||
session: The chat session to append usage to (may be None on error).
|
||||
user_id: User ID for rate-limit counters (skipped if None).
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
|
||||
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
if prompt_tokens <= 0 and completion_tokens <= 0:
|
||||
return 0
|
||||
|
||||
# total_tokens = prompt + completion. Cache tokens are tracked
|
||||
# separately and excluded from total so both baseline and SDK
|
||||
# paths share the same semantics.
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
if session is not None:
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
if cache_read_tokens or cache_creation_tokens:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
|
||||
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
|
||||
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
|
||||
f"completion={completion_tokens}, total={total_tokens}"
|
||||
)
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
await record_token_usage(
|
||||
user_id=user_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
|
||||
|
||||
return total_tokens
|
||||
281
autogpt_platform/backend/backend/copilot/token_tracking_test.py
Normal file
281
autogpt_platform/backend/backend/copilot/token_tracking_test.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Unit tests for token_tracking.persist_and_record_usage.
|
||||
|
||||
Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
|
||||
calling conventions, session persistence, and rate-limit recording.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .token_tracking import persist_and_record_usage
|
||||
|
||||
|
||||
def _make_session() -> ChatSession:
|
||||
"""Return a minimal in-memory ChatSession for testing."""
|
||||
return ChatSession(
|
||||
session_id="sess-test",
|
||||
user_id="user-test",
|
||||
title=None,
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Return value / total_tokens semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTotalTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_prompt_plus_completion(self):
|
||||
"""total_tokens = prompt + completion (cache excluded from total)."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=300,
|
||||
completion_tokens=200,
|
||||
)
|
||||
assert total == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_zero_when_no_tokens(self):
|
||||
"""Returns 0 early when both prompt and completion are zero."""
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_excluded_from_total(self):
|
||||
"""Cache tokens are stored separately and not added to total_tokens."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=5000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
# total = prompt + completion only (5000 + 200 cache excluded)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_path_no_cache(self):
|
||||
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="u1",
|
||||
prompt_tokens=1000,
|
||||
completion_tokens=400,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
assert total == 1400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sdk_path_with_cache(self):
|
||||
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="u2",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=8000,
|
||||
cache_creation_tokens=400,
|
||||
log_prefix="[SDK]",
|
||||
cost_usd=0.0015,
|
||||
)
|
||||
assert total == 300
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
@pytest.mark.asyncio
|
||||
async def test_appends_usage_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert len(session.usage) == 1
|
||||
usage: Usage = session.usage[0]
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
assert usage.cache_read_tokens == 0
|
||||
assert usage.cache_creation_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_appends_cache_breakdown_to_session(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=200,
|
||||
completion_tokens=80,
|
||||
cache_read_tokens=3000,
|
||||
cache_creation_tokens=500,
|
||||
)
|
||||
usage: Usage = session.usage[0]
|
||||
assert usage.cache_read_tokens == 3000
|
||||
assert usage.cache_creation_tokens == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_turns_append_multiple_records(self):
|
||||
session = _make_session()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=session, user_id=None, prompt_tokens=100, completion_tokens=50
|
||||
)
|
||||
await persist_and_record_usage(
|
||||
session=session, user_id=None, prompt_tokens=200, completion_tokens=70
|
||||
)
|
||||
assert len(session.usage) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_session_does_not_raise(self):
|
||||
"""When session is None (e.g. error path), no exception should be raised."""
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_append_when_zero_tokens(self):
|
||||
"""When tokens are zero, function returns early — session unchanged."""
|
||||
session = _make_session()
|
||||
total = await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
assert total == 0
|
||||
assert len(session.usage) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate-limit recording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitRecording:
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_record_token_usage_when_user_id_present(self):
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
mock_record.assert_awaited_once_with(
|
||||
user_id="user-abc",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_user_id_is_none(self):
|
||||
"""Anonymous sessions should not create Redis keys."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_failure_does_not_raise(self):
|
||||
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
|
||||
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
# Should not raise
|
||||
total = await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-xyz",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
assert total == 150
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_record_when_zero_tokens(self):
|
||||
"""Returns 0 before calling record_token_usage when tokens are zero."""
|
||||
mock_record = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new=mock_record,
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-abc",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
@@ -41,7 +41,8 @@ import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox, SandboxLifecycle
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
|
||||
@@ -8,11 +8,13 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.db_accessors import credit_db, workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
from backend.util.exceptions import BlockError, InsufficientBalanceError
|
||||
from backend.util.type import coerce_inputs_to_schema
|
||||
|
||||
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
|
||||
@@ -115,6 +117,21 @@ async def execute_block(
|
||||
# Coerce non-matching data types to the expected input schema.
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
|
||||
# Pre-execution credit check (courtesy; spend_credits is atomic)
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
_credit_db = credit_db()
|
||||
if has_cost:
|
||||
balance = await _credit_db.get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Insufficient credits to run '{block.name}'. "
|
||||
"Please top up your credits to continue."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
@@ -123,6 +140,51 @@ async def execute_block(
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _credit_db.spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=synthetic_graph_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=block_id,
|
||||
block=block.name,
|
||||
input=cost_filter,
|
||||
reason="copilot_block_execution",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Block already executed (with possible side effects). Never
|
||||
# return ErrorResponse here — the user received output and
|
||||
# deserves it. Log the billing failure for reconciliation.
|
||||
leak_type = (
|
||||
"INSUFFICIENT_BALANCE"
|
||||
if isinstance(e, InsufficientBalanceError)
|
||||
else "UNEXPECTED_ERROR"
|
||||
)
|
||||
logger.error(
|
||||
"BILLING_LEAK[%s]: block executed but credit charge failed — "
|
||||
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
|
||||
leak_type,
|
||||
user_id,
|
||||
block_id,
|
||||
node_exec_id,
|
||||
cost,
|
||||
e,
|
||||
extra={
|
||||
"json_fields": {
|
||||
"billing_leak": True,
|
||||
"leak_type": leak_type,
|
||||
"user_id": user_id,
|
||||
"cost": str(cost),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
@@ -133,14 +195,14 @@ async def execute_block(
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
logger.warning("Block execution failed: %s", e)
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
|
||||
@@ -1,18 +1,197 @@
|
||||
"""Tests for execute_block type coercion in helpers.py.
|
||||
|
||||
Verifies that execute_block() coerces string input values to match the block's
|
||||
expected input types, mirroring the executor's validate_exec() logic.
|
||||
This is critical for @@agptfile: expansion, where file content is always a string
|
||||
but the block may expect structured types (e.g. list[list[str]]).
|
||||
"""
|
||||
"""Tests for execute_block — credit charging and type coercion."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import BlockType
|
||||
from backend.copilot.tools.helpers import execute_block
|
||||
from backend.copilot.tools.models import BlockOutputResponse
|
||||
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
|
||||
|
||||
_USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = BlockType.STANDARD
|
||||
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.get_credentials_fields_info.return_value = {}
|
||||
|
||||
async def _execute(
|
||||
input_data: dict, **kwargs: Any
|
||||
) -> AsyncIterator[tuple[str, Any]]:
|
||||
yield "result", "ok"
|
||||
|
||||
mock.execute = _execute
|
||||
return mock
|
||||
|
||||
|
||||
def _patch_workspace():
|
||||
"""Patch workspace_db to return a mock workspace."""
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.id = "ws-1"
|
||||
mock_ws_db = MagicMock()
|
||||
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
def _patch_credit_db(
|
||||
get_credits_return: int = 100,
|
||||
spend_credits_side_effect: Any = None,
|
||||
):
|
||||
"""Patch credit_db accessor to return a mock credit adapter."""
|
||||
mock_credit = MagicMock()
|
||||
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
|
||||
if spend_credits_side_effect is not None:
|
||||
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
|
||||
else:
|
||||
mock_credit.spend_credits = AsyncMock()
|
||||
return (
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.credit_db",
|
||||
return_value=mock_credit,
|
||||
),
|
||||
mock_credit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={"text": "hello"},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_credit.spend_credits.assert_awaited_once()
|
||||
call_kwargs = mock_credit.spend_credits.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "Insufficient credits" in result.message
|
||||
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_credit.get_credits.assert_not_awaited()
|
||||
mock_credit.spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_output_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, output is still returned (block already ran)."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(
|
||||
get_credits_return=15,
|
||||
spend_credits_side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
block_id="block-1",
|
||||
input_data={},
|
||||
user_id=_USER,
|
||||
session_id=_SESSION,
|
||||
node_exec_id="exec-1",
|
||||
matched_credentials={},
|
||||
)
|
||||
|
||||
# Block already executed (with side effects), so output is returned
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
@@ -28,7 +207,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
|
||||
return schema
|
||||
|
||||
|
||||
def _make_block(
|
||||
def _make_coerce_block(
|
||||
block_id: str,
|
||||
name: str,
|
||||
annotations: dict[str, Any],
|
||||
@@ -60,7 +239,7 @@ _TEST_USER_ID = "test-user-coerce"
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
{"values": list[list[str]], "spreadsheet_id": str},
|
||||
@@ -103,7 +282,7 @@ async def test_coerce_json_string_to_nested_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
{"items": list[str]},
|
||||
@@ -135,7 +314,7 @@ async def test_coerce_json_string_to_list():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
{"config": dict[str, str]},
|
||||
@@ -167,7 +346,7 @@ async def test_coerce_json_string_to_dict():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_coercion_when_type_matches():
|
||||
"""Already-correct types pass through without coercion."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"pass-through",
|
||||
"Pass Through",
|
||||
{"values": list[list[str]], "name": str},
|
||||
@@ -201,7 +380,7 @@ async def test_no_coercion_when_type_matches():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
{"count": int},
|
||||
@@ -234,7 +413,7 @@ async def test_coerce_string_to_int():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_skips_none_values():
|
||||
"""None values are not coerced (they may be optional fields)."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"optional-block",
|
||||
"Optional Block",
|
||||
{"data": list[str], "label": str},
|
||||
@@ -267,7 +446,7 @@ async def test_coerce_skips_none_values():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_union_type_preserves_valid_member():
|
||||
"""Union-typed fields should not be coerced when the value matches a member."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"union-block",
|
||||
"Union Block",
|
||||
{"content": str | list[str]},
|
||||
@@ -301,7 +480,7 @@ async def test_coerce_union_type_preserves_valid_member():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_inner_elements_of_generic():
|
||||
"""Inner elements of generic containers are recursively coerced."""
|
||||
block = _make_block(
|
||||
block = _make_coerce_block(
|
||||
"inner-coerce",
|
||||
"Inner Coerce",
|
||||
{"values": list[str]},
|
||||
|
||||
@@ -129,3 +129,16 @@ def review_db():
|
||||
review_db = get_database_manager_async_client()
|
||||
|
||||
return review_db
|
||||
|
||||
|
||||
def credit_db():
|
||||
if db.is_connected():
|
||||
from backend.data import db_manager as _credit_db
|
||||
|
||||
credit_db = _credit_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
credit_db = get_database_manager_async_client()
|
||||
|
||||
return credit_db
|
||||
|
||||
@@ -148,6 +148,11 @@ async def _get_credits(user_id: str) -> int:
|
||||
return await user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
# Public aliases used by db_accessors.credit_db() when Prisma is connected
|
||||
get_credits = _get_credits
|
||||
spend_credits = _spend_credits
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
"""Database connection pooling service.
|
||||
|
||||
@@ -512,6 +517,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
list_workspace_files = d.list_workspace_files
|
||||
soft_delete_workspace_file = d.soft_delete_workspace_file
|
||||
|
||||
# ============ Credits ============ #
|
||||
spend_credits = d.spend_credits
|
||||
get_credits = d.get_credits
|
||||
|
||||
# ============ Understanding ============ #
|
||||
get_business_understanding = d.get_business_understanding
|
||||
upsert_business_understanding = d.upsert_business_understanding
|
||||
|
||||
@@ -70,10 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens (standard: "text" key, fallback: "content")
|
||||
text_val = item.get("text") or item.get("content", "")
|
||||
tool_call_tokens += _tok_len(text_val, enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -149,16 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
if max_tok < 1:
|
||||
return ""
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(ids[:max_tok])
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -555,14 +545,6 @@ async def _summarize_messages_llm(
|
||||
"- Actions taken and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
|
||||
"critical for continuing the conversation:\n"
|
||||
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
|
||||
"- Image/media file paths from tool outputs\n"
|
||||
"- URLs, API endpoints, and webhook addresses\n"
|
||||
"- Resource IDs, session IDs, and identifiers\n"
|
||||
"- Tool names that were called and their key parameters\n"
|
||||
"- Environment variables, config keys, and credentials names (not values)\n\n"
|
||||
"Include ONLY the sections below that have relevant content "
|
||||
"(skip sections with nothing to report):\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
@@ -570,8 +552,7 @@ async def _summarize_messages_llm(
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers. "
|
||||
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions.\n\n"
|
||||
"## 5. All User Messages\n"
|
||||
@@ -585,7 +566,7 @@ async def _summarize_messages_llm(
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=2000,
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
@@ -705,15 +686,11 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/molecules/DropdownMenu/DropdownMenu";
|
||||
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
|
||||
import { UploadSimple } from "@phosphor-icons/react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||
@@ -89,10 +83,9 @@ export function CopilotPage() {
|
||||
handleDrawerOpenChange,
|
||||
handleSelectSession,
|
||||
handleNewChat,
|
||||
// Delete functionality
|
||||
// Delete functionality (available via ChatSidebar context menu on all viewports)
|
||||
sessionToDelete,
|
||||
isDeleting,
|
||||
handleDeleteClick,
|
||||
handleConfirmDelete,
|
||||
handleCancelDelete,
|
||||
} = useCopilotPage();
|
||||
@@ -148,38 +141,6 @@ export function CopilotPage() {
|
||||
isUploadingFiles={isUploadingFiles}
|
||||
droppedFiles={droppedFiles}
|
||||
onDroppedFilesConsumed={handleDroppedFilesConsumed}
|
||||
headerSlot={
|
||||
isMobile && sessionId ? (
|
||||
<div className="flex justify-end">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1.5 hover:bg-neutral-100"
|
||||
aria-label="More actions"
|
||||
>
|
||||
<DotsThree className="h-5 w-5 text-neutral-600" />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
const session = sessions.find(
|
||||
(s) => s.id === sessionId,
|
||||
);
|
||||
if (session) {
|
||||
handleDeleteClick(session.id, session.title);
|
||||
}
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="text-red-600 focus:bg-red-50 focus:text-red-600"
|
||||
>
|
||||
Delete chat
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { LayoutGroup, motion } from "framer-motion";
|
||||
import { ReactNode } from "react";
|
||||
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||
import { EmptySession } from "../EmptySession/EmptySession";
|
||||
@@ -21,7 +20,6 @@ export interface ChatContainerProps {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
onStop: () => void;
|
||||
isUploadingFiles?: boolean;
|
||||
headerSlot?: ReactNode;
|
||||
/** Files dropped onto the chat window. */
|
||||
droppedFiles?: File[];
|
||||
/** Called after droppedFiles have been consumed by ChatInput. */
|
||||
@@ -40,7 +38,6 @@ export const ChatContainer = ({
|
||||
onSend,
|
||||
onStop,
|
||||
isUploadingFiles,
|
||||
headerSlot,
|
||||
droppedFiles,
|
||||
onDroppedFilesConsumed,
|
||||
}: ChatContainerProps) => {
|
||||
@@ -63,7 +60,6 @@ export const ChatContainer = ({
|
||||
status={status}
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
headerSlot={headerSlot}
|
||||
sessionID={sessionId}
|
||||
/>
|
||||
<motion.div
|
||||
|
||||
@@ -30,7 +30,6 @@ interface Props {
|
||||
status: string;
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
headerSlot?: React.ReactNode;
|
||||
sessionID?: string | null;
|
||||
}
|
||||
|
||||
@@ -102,7 +101,6 @@ export function ChatMessagesContainer({
|
||||
status,
|
||||
error,
|
||||
isLoading,
|
||||
headerSlot,
|
||||
sessionID,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
@@ -135,7 +133,6 @@ export function ChatMessagesContainer({
|
||||
return (
|
||||
<Conversation className="min-h-0 flex-1">
|
||||
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
|
||||
{headerSlot}
|
||||
{isLoading && messages.length === 0 && (
|
||||
<div
|
||||
className="flex flex-1 items-center justify-center"
|
||||
|
||||
@@ -37,6 +37,7 @@ import { useCopilotUIStore } from "../../store";
|
||||
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
|
||||
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
|
||||
import { PulseLoader } from "../PulseLoader/PulseLoader";
|
||||
import { UsageLimits } from "../UsageLimits/UsageLimits";
|
||||
|
||||
export function ChatSidebar() {
|
||||
const { state } = useSidebar();
|
||||
@@ -256,11 +257,10 @@ export function ChatSidebar() {
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
<div className="relative left-5 flex items-center gap-1">
|
||||
<div className="flex items-center">
|
||||
<UsageLimits />
|
||||
<NotificationToggle />
|
||||
<div className="relative left-1">
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
<SidebarTrigger />
|
||||
</div>
|
||||
</div>
|
||||
{sessionId ? (
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
|
||||
import { useCopilotUIStore } from "../../../../store";
|
||||
@@ -48,10 +49,7 @@ export function NotificationToggle() {
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
|
||||
aria-label="Notification settings"
|
||||
>
|
||||
<Button variant="ghost" size="icon" aria-label="Notification settings">
|
||||
{!isNotificationsEnabled ? (
|
||||
<BellSlash className="!size-5" />
|
||||
) : isSoundEnabled ? (
|
||||
@@ -59,7 +57,7 @@ export function NotificationToggle() {
|
||||
) : (
|
||||
<Bell className="!size-5" />
|
||||
)}
|
||||
</button>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-56 p-3">
|
||||
<div className="flex flex-col gap-3">
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/molecules/Popover/Popover";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChartBar } from "@phosphor-icons/react";
|
||||
import { UsagePanelContent } from "./UsagePanelContent";
|
||||
|
||||
export { UsagePanelContent, formatResetTime } from "./UsagePanelContent";
|
||||
|
||||
export function UsageLimits() {
|
||||
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon" aria-label="Usage limits">
|
||||
<ChartBar className="!size-5" weight="light" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent align="start" className="w-64 p-3">
|
||||
<UsagePanelContent usage={usage} />
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import Link from "next/link";
|
||||
|
||||
export function formatResetTime(
|
||||
resetsAt: Date | string,
|
||||
now: Date = new Date(),
|
||||
): string {
|
||||
const resetDate =
|
||||
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
|
||||
const diffMs = resetDate.getTime() - now.getTime();
|
||||
if (diffMs <= 0) return "now";
|
||||
|
||||
const hours = Math.floor(diffMs / (1000 * 60 * 60));
|
||||
|
||||
// Under 24h: show relative time ("in 4h 23m")
|
||||
if (hours < 24) {
|
||||
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
|
||||
if (hours > 0) return `in ${hours}h ${minutes}m`;
|
||||
return `in ${minutes}m`;
|
||||
}
|
||||
|
||||
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
|
||||
return resetDate.toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
}
|
||||
|
||||
function UsageBar({
|
||||
label,
|
||||
used,
|
||||
limit,
|
||||
resetsAt,
|
||||
}: {
|
||||
label: string;
|
||||
used: number;
|
||||
limit: number;
|
||||
resetsAt: Date | string;
|
||||
}) {
|
||||
if (limit <= 0) return null;
|
||||
|
||||
const rawPercent = (used / limit) * 100;
|
||||
const percent = Math.min(100, Math.round(rawPercent));
|
||||
const isHigh = percent >= 80;
|
||||
const percentLabel =
|
||||
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex items-baseline justify-between">
|
||||
<span className="text-xs font-medium text-neutral-700">{label}</span>
|
||||
<span className="text-[11px] tabular-nums text-neutral-500">
|
||||
{percentLabel}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-neutral-400">
|
||||
Resets {formatResetTime(resetsAt)}
|
||||
</div>
|
||||
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||
<div
|
||||
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
|
||||
isHigh ? "bg-orange-500" : "bg-blue-500"
|
||||
}`}
|
||||
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function UsagePanelContent({
|
||||
usage,
|
||||
showBillingLink = true,
|
||||
}: {
|
||||
usage: CoPilotUsageStatus;
|
||||
showBillingLink?: boolean;
|
||||
}) {
|
||||
const hasDailyLimit = usage.daily.limit > 0;
|
||||
const hasWeeklyLimit = usage.weekly.limit > 0;
|
||||
|
||||
if (!hasDailyLimit && !hasWeeklyLimit) {
|
||||
return (
|
||||
<div className="text-xs text-neutral-500">No usage limits configured</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
|
||||
{hasDailyLimit && (
|
||||
<UsageBar
|
||||
label="Today"
|
||||
used={usage.daily.used}
|
||||
limit={usage.daily.limit}
|
||||
resetsAt={usage.daily.resets_at}
|
||||
/>
|
||||
)}
|
||||
{hasWeeklyLimit && (
|
||||
<UsageBar
|
||||
label="This week"
|
||||
used={usage.weekly.used}
|
||||
limit={usage.weekly.limit}
|
||||
resetsAt={usage.weekly.resets_at}
|
||||
/>
|
||||
)}
|
||||
{showBillingLink && (
|
||||
<Link
|
||||
href="/profile/credits"
|
||||
className="text-[11px] text-blue-600 hover:underline"
|
||||
>
|
||||
Learn more about usage limits
|
||||
</Link>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UsageLimits } from "../UsageLimits";
|
||||
|
||||
// Mock the generated Orval hook
|
||||
const mockUseGetV2GetCopilotUsage = vi.fn();
|
||||
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
|
||||
useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts),
|
||||
}));
|
||||
|
||||
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
|
||||
vi.mock("@/components/molecules/Popover/Popover", () => ({
|
||||
Popover: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
PopoverContent: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseGetV2GetCopilotUsage.mockReset();
|
||||
});
|
||||
|
||||
function makeUsage({
|
||||
dailyUsed = 500,
|
||||
dailyLimit = 10000,
|
||||
weeklyUsed = 2000,
|
||||
weeklyLimit = 50000,
|
||||
}: {
|
||||
dailyUsed?: number;
|
||||
dailyLimit?: number;
|
||||
weeklyUsed?: number;
|
||||
weeklyLimit?: number;
|
||||
} = {}) {
|
||||
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
|
||||
return {
|
||||
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
|
||||
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
|
||||
};
|
||||
}
|
||||
|
||||
describe("UsageLimits", () => {
|
||||
it("renders nothing while loading", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: true,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders nothing when no limits are configured", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
|
||||
isLoading: false,
|
||||
});
|
||||
const { container } = render(<UsageLimits />);
|
||||
expect(container.innerHTML).toBe("");
|
||||
});
|
||||
|
||||
it("renders the usage button when limits exist", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
|
||||
});
|
||||
|
||||
it("displays daily and weekly usage percentages", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("50% used")).toBeDefined();
|
||||
expect(screen.getByText("Today")).toBeDefined();
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.getByText("Usage limits")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows only weekly bar when daily limit is 0", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({
|
||||
dailyLimit: 0,
|
||||
weeklyUsed: 25000,
|
||||
weeklyLimit: 50000,
|
||||
}),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("This week")).toBeDefined();
|
||||
expect(screen.queryByText("Today")).toBeNull();
|
||||
});
|
||||
|
||||
it("caps percentage at 100% when over limit", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
expect(screen.getByText("100% used")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows learn more link to credits page", () => {
|
||||
mockUseGetV2GetCopilotUsage.mockReturnValue({
|
||||
data: makeUsage(),
|
||||
isLoading: false,
|
||||
});
|
||||
render(<UsageLimits />);
|
||||
|
||||
const link = screen.getByText("Learn more about usage limits");
|
||||
expect(link).toBeDefined();
|
||||
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
getGetV2GetCopilotUsageQueryKey,
|
||||
getGetV2GetSessionQueryKey,
|
||||
postV2CancelSessionTask,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
@@ -177,12 +178,41 @@ export function useCopilotStream({
|
||||
onError: (error) => {
|
||||
if (!sessionId) return;
|
||||
|
||||
// Detect rate limit (429) responses and show reset time to the user.
|
||||
// The SDK throws a plain Error whose message is the raw response body
|
||||
// (FastAPI returns {"detail": "...usage limit..."} for 429s).
|
||||
let errorDetail: string = error.message;
|
||||
try {
|
||||
const parsed = JSON.parse(error.message) as unknown;
|
||||
if (
|
||||
typeof parsed === "object" &&
|
||||
parsed !== null &&
|
||||
"detail" in parsed &&
|
||||
typeof (parsed as { detail: unknown }).detail === "string"
|
||||
) {
|
||||
errorDetail = (parsed as { detail: string }).detail;
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — use message as-is
|
||||
}
|
||||
const isRateLimited = errorDetail.toLowerCase().includes("usage limit");
|
||||
if (isRateLimited) {
|
||||
toast({
|
||||
title: "Usage limit reached",
|
||||
description:
|
||||
errorDetail ||
|
||||
"You've reached your usage limit. Please try again later.",
|
||||
variant: "destructive",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Detect authentication failures (from getAuthHeaders or 401 responses)
|
||||
const isAuthError =
|
||||
error.message.includes("Authentication failed") ||
|
||||
error.message.includes("Unauthorized") ||
|
||||
error.message.includes("Not authenticated") ||
|
||||
error.message.toLowerCase().includes("401");
|
||||
errorDetail.includes("Authentication failed") ||
|
||||
errorDetail.includes("Unauthorized") ||
|
||||
errorDetail.includes("Not authenticated") ||
|
||||
errorDetail.toLowerCase().includes("401");
|
||||
if (isAuthError) {
|
||||
toast({
|
||||
title: "Authentication error",
|
||||
@@ -307,6 +337,9 @@ export function useCopilotStream({
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetCopilotUsageQueryKey(),
|
||||
});
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
|
||||
@@ -11,6 +11,9 @@ import {
|
||||
|
||||
import { RefundModal } from "./RefundModal";
|
||||
import { CreditTransaction } from "@/lib/autogpt-server-api";
|
||||
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
|
||||
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
|
||||
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
import {
|
||||
Table,
|
||||
@@ -21,6 +24,32 @@ import {
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
|
||||
function CoPilotUsageSection() {
|
||||
const router = useRouter();
|
||||
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
|
||||
query: {
|
||||
select: (res) => res.data as CoPilotUsageStatus,
|
||||
refetchInterval: 30000,
|
||||
staleTime: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
if (isLoading || !usage) return null;
|
||||
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
|
||||
|
||||
return (
|
||||
<div className="my-6 space-y-4">
|
||||
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
|
||||
<div className="rounded-lg border border-neutral-200 p-4">
|
||||
<UsagePanelContent usage={usage} showBillingLink={false} />
|
||||
</div>
|
||||
<Button className="w-full" onClick={() => router.push("/copilot")}>
|
||||
Open CoPilot
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CreditsPage() {
|
||||
const api = useBackendAPI();
|
||||
const {
|
||||
@@ -237,11 +266,13 @@ export default function CreditsPage() {
|
||||
</Button>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* CoPilot Usage Limits */}
|
||||
<CoPilotUsageSection />
|
||||
</div>
|
||||
|
||||
<div className="my-6 space-y-4">
|
||||
{/* Payment Portal */}
|
||||
|
||||
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
|
||||
<p className="text-neutral-600">
|
||||
You can manage your cards and see your payment history in the
|
||||
|
||||
@@ -1267,7 +1267,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1382,6 +1382,28 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/usage": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Copilot Usage",
|
||||
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
|
||||
"operationId": "getV2GetCopilotUsage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -8455,6 +8477,16 @@
|
||||
"title": "ClarifyingQuestion",
|
||||
"description": "A question that needs user clarification."
|
||||
},
|
||||
"CoPilotUsageStatus": {
|
||||
"properties": {
|
||||
"daily": { "$ref": "#/components/schemas/UsageWindow" },
|
||||
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["daily", "weekly"],
|
||||
"title": "CoPilotUsageStatus",
|
||||
"description": "Current usage status for a user across all windows."
|
||||
},
|
||||
"ContentType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -12190,6 +12222,16 @@
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
},
|
||||
"total_prompt_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Prompt Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"total_completion_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Completion Tokens",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -14587,6 +14629,25 @@
|
||||
"required": ["timezone"],
|
||||
"title": "UpdateTimezoneRequest"
|
||||
},
|
||||
"UsageWindow": {
|
||||
"properties": {
|
||||
"used": { "type": "integer", "title": "Used" },
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"title": "Limit",
|
||||
"description": "Maximum tokens allowed in this window. 0 means unlimited."
|
||||
},
|
||||
"resets_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Resets At"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["used", "limit", "resets_at"],
|
||||
"title": "UsageWindow",
|
||||
"description": "Usage within a single time window."
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
|
||||
@@ -288,6 +288,7 @@ const SidebarTrigger = React.forwardRef<
|
||||
ref={ref}
|
||||
data-sidebar="trigger"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={(event) => {
|
||||
onClick?.(event);
|
||||
toggleSidebar();
|
||||
|
||||
Reference in New Issue
Block a user