Compare commits

..

23 Commits

Author SHA1 Message Date
Zamil Majdy
88eaab2baa Merge remote-tracking branch 'origin/dev' into feat/github-cli-copilot 2026-03-17 06:17:03 +07:00
Zamil Majdy
4b0a445635 fix(copilot): remove implicit gh auth setup-git from sandbox creation
Remove the automatic GitHub credential helper configuration that ran on
every E2B sandbox connect/reconnect. This addressed a review concern
about implicitly giving AutoPilot full GitHub access without user
awareness or opt-in.

The bash_exec tool already injects GH_TOKEN/GITHUB_TOKEN per-command
for users who have connected their account via connect_integration,
which is the explicit opt-in path.
2026-03-17 00:36:51 +07:00
Zamil Majdy
36312d2c6e fix(backend/copilot): bust cache on OAuth refresh + persist dismissed state
- creds_manager: call _bust_copilot_cache after refresh_if_needed
  persists the refreshed token so the copilot cache doesn't serve
  a stale access token after silent refresh
- ConnectIntegrationTool: lift isDismissed state to parent so
  SetupRequirementsCard remounts don't re-enable the Proceed button;
  onComplete callback propagates the dismissed signal up
2026-03-16 17:10:18 +07:00
Zamil Majdy
d6d3b8d710 fix(copilot): address coderabbitai major issues — scope token description to E2B, guard cache-bust hook
- connect_integration.py: clarify that GH_TOKEN is injected per-command in
  E2B/cloud only; note that bubblewrap isolates network so retry won't work
- creds_manager._bust_copilot_cache: wrap _on_creds_changed in try/except so
  a failing hook doesn't turn successful create/update/delete into a 500
2026-03-16 15:52:40 +07:00
Zamil Majdy
17d8d0bf05 fix(backend/copilot): run gh auth setup-git once on sandbox connect/reconnect
Move git credential helper setup out of bash_exec (where it ran on
every command) and into _setup_e2b so it runs exactly once per sandbox
connect or reconnect. Non-fatal: logged at debug level on failure.
2026-03-16 15:45:18 +07:00
Zamil Majdy
5a2ab65f41 fix(backend/copilot): run gh auth setup-git once per sandbox session
Use grep to skip re-running if the credential helper is already
configured in ~/.gitconfig — only pays the cost on first command.
Agent can still call it manually if GH_TOKEN changes mid-session.
2026-03-16 15:42:54 +07:00
Zamil Majdy
81a318de3e feat(backend/copilot): improve GitHub OAuth UX and git auth
- Dynamic OAuth scopes: connect_integration tool now accepts a `scopes`
  param so the agent can request exactly the access it needs (e.g.
  `["repo", "read:org"]`); GitHub defaults to `["repo"]` so git
  push/pull works out of the box instead of public-data-only
- Lazy git auth: prepend `gh auth setup-git` on every E2B bash_exec
  when GH_TOKEN is present — git HTTPS clone/push/pull now work
  automatically without the agent needing to set this up manually
- Prefer broadest-scoped OAuth2 credential: sort repo-scoped tokens
  first so a stale public-data token is never picked over a full one
- Collapse SetupRequirementsCard to "Connected. Continuing…" after
  Proceed is clicked instead of leaving the full card visible
- Fix credential auto-select: don't silently pick the first token when
  multiple credentials exist — let the user choose via the dropdown
2026-03-16 15:26:14 +07:00
Zamil Majdy
62c8e8634b fix(copilot): patch _manager singleton directly in tests instead of class constructor
The module-level _manager singleton is created at import time, so patching
IntegrationCredentialsManager after import has no effect. Patch the _manager
attribute directly so get_provider_token uses the mock.
2026-03-16 06:32:36 +07:00
Zamil Majdy
b91c959cd9 fix(copilot): address remaining review findings
- creds_manager: fix TOCTOU in delete() — move get_creds_by_id inside the lock
- creds_manager: replace lazy import in _bust_copilot_cache with a
  register_creds_changed_hook() callback so creds_manager has no runtime
  dependency on the copilot module
- integration_creds: register invalidate_user_provider_cache at module import
  via register_creds_changed_hook() — eliminates the circular-import risk
- integration_creds: add module-level _manager singleton (avoids re-instantiating
  IntegrationCredentialsManager on every cache miss)
- integration_creds: document TTLCache asyncio-only thread-safety assumption
- connect_integration: defer GITHUB_OAUTH_IS_CONFIGURED evaluation to runtime
  with an lru_cache'd helper; importing the module no longer triggers Secrets()
- connect_integration: type missing_credentials dict with _CredentialEntry TypedDict
- connect_integration: cap reason field at 500 chars; add maxLength to JSON schema
- bash_exec: use 'user_id is not None' instead of truthy check
- connect_integration_test: add test for unauthenticated caller (requires_auth guard)
- bash_exec_test: add E2B path tests — token injected when user_id set, skipped
  when user_id is None
- ConnectIntegrationTool.tsx: sanitize LLM-controlled providerName fallback
  (trim + slice to 64 chars)
2026-03-16 06:21:44 +07:00
Zamil Majdy
5b95a2a1ef refactor(copilot): strongly type _PROVIDER_INFO with TypedDict
Replace dict[str, Any] with a _ProviderInfo TypedDict for provider
metadata entries, eliminating key/type drift as new providers are added.
2026-03-16 06:04:02 +07:00
Zamil Majdy
9c2a601167 refactor(copilot): simplify cache with cachetools.TTLCache, fix prompt wording
- Replace manual dict+sentinel cache with two TTLCache instances:
  _token_cache (5min TTL) and _null_cache (60s TTL)
- Remove _cache_set helper and _NO_TOKEN sentinel — TTLCache handles
  expiry and LRU eviction natively
- Update tests to use _token_cache/_null_cache directly; add TTL constant test
- Change _E2B_TOOL_NOTES from "GH_TOKEN is set" to "gh is pre-authenticated"
  so the AI doesn't attempt to read the env var directly
2026-03-16 00:16:26 +07:00
Zamil Majdy
b98e37bf23 refactor(copilot): DRY cache-bust helper, fast eviction test, unified JSON parse
Backend:
- Extract _bust_copilot_cache() in creds_manager.py; create/update/delete
  now each call it once instead of repeating the try/except ImportError block
- test_evicts_oldest_when_full: patch _CACHE_MAX_SIZE to 3 to avoid
  allocating 10 000 entries in CI; remove now-unused _CACHE_MAX_SIZE import

Frontend:
- Extract parseJson() helper shared by parseOutput and parseError in
  ConnectIntegrationTool.tsx, eliminating duplicated try/catch logic
2026-03-16 00:01:10 +07:00
Zamil Majdy
fec8924361 fix(copilot): bust token cache on update/delete, tighten except clause
- creds_manager.create/update/delete now all call invalidate_user_provider_cache
  after mutating credentials, so the next bash_exec always picks up the
  current state without waiting for TTL to expire
- Change broad `except Exception` to `except ImportError` in all three methods
  so real bugs inside invalidate_user_provider_cache are not silently swallowed
- delete() reads the provider before deletion so we know which cache key to evict
- Add tests for invalidate_user_provider_cache: removes sentinel/token entry,
  no-op when key absent, only removes the targeted key
2026-03-15 23:57:11 +07:00
Zamil Majdy
712aee7302 fix(copilot): warn on stale OAuth token fallback, document per-process cache
- Log at WARNING (not DEBUG) when OAuth refresh fails and we fall back to
  a potentially stale token, so operators can diagnose repeated auth failures
- Add multi-worker note to module docstring: _token_cache is process-local;
  each replica maintains its own cache (acceptable for current goal, but a
  shared cache would be needed for cross-replica efficiency)
2026-03-15 23:53:10 +07:00
Zamil Majdy
bef292033e fix(copilot): render error state in ConnectIntegrationTool
When part.state is 'output-error', show the error message from the
backend (ErrorResponse.message) in red text below the status line.
Without this, errors from unknown/unsupported providers were silently
discarded, leaving the user without any feedback.
2026-03-15 23:51:09 +07:00
Zamil Majdy
ec6974e3b8 fix(copilot): invalidate null cache on credential creation
When a user connects an integration, IntegrationCredentialsManager.create()
now calls invalidate_user_provider_cache() to remove any stale _NO_TOKEN
sentinel from the TTL cache. Without this, the first retry after connecting
would still return None for up to _NULL_CACHE_TTL (60 s).

The import is done lazily inside create() to avoid a circular import
between integrations.creds_manager and copilot.integration_creds.
2026-03-15 23:50:34 +07:00
Zamil Majdy
2ef5e2fe77 feat(copilot): bounded TTL cache with sentinel for integration creds
- Replace empty-string sentinel with explicit _NO_TOKEN = object() to
  avoid ambiguity with zero-length tokens
- Bound _token_cache to _CACHE_MAX_SIZE=10_000 entries; _cache_set()
  evicts oldest insertion-order entry when full
- Cache "not connected" results with _NULL_CACHE_TTL=60s (vs 300s for
  found tokens) to avoid a DB hit on every E2B bash_exec for users who
  haven't connected yet, while still picking up a new connection quickly
- Add integration_creds_test.py covering all cache paths, sentinel,
  eviction, OAuth2 preferred/fallback, DB exception, and env var injection
2026-03-15 23:46:05 +07:00
Zamil Majdy
0a8c7221ce fix(copilot): address all review findings
- prompting: rename _SDK_TOOL_NOTES → _E2B_TOOL_NOTES; pass it only to
  _get_cloud_sandbox_supplement() via new extra_notes param — local
  (bubblewrap) mode uses --unshare-net so gh CLI cannot reach GitHub
- integration_creds: cache None results with 60 s TTL (_NULL_CACHE_TTL)
  to avoid a DB hit on every E2B bash_exec for users without GitHub creds;
  found tokens still cached for 5 min (_TOKEN_CACHE_TTL)
- connect_integration: add cross-reference comment to PROVIDER_ENV_VARS
- ConnectIntegrationTool: use provider-specific credentialsLabel
  (e.g. "GitHub credentials" instead of "Integration credentials")
2026-03-15 23:35:25 +07:00
Zamil Majdy
840d1de636 refactor(copilot): move token injection to bash_exec + add integration_creds module
- Extract integration token lookup into backend/copilot/integration_creds.py:
  * Generic get_provider_token(user_id, provider) with 5-min TTL cache
  * get_integration_env_vars(user_id) loops over PROVIDER_ENV_VARS registry
  * Adding a new provider only requires a one-line PROVIDER_ENV_VARS entry

- Inject tokens lazily in bash_exec._execute_on_e2b (E2B has internet access;
  bubblewrap uses --unshare-net so gh CLI cannot reach GitHub regardless)

- Remove eager per-turn GH_TOKEN injection from sdk/service.py (wrong layer:
  bubblewrap is network-isolated, E2B injection now done per-command in bash_exec)

- Fix unsafe output.setup_info?.agent_name access in ConnectIntegrationTool

- Add connect_integration_test.py: unknown provider, known provider structure,
  reason in message, session_id propagation, case-insensitive provider slug
2026-03-15 23:29:33 +07:00
Zamil Majdy
ac55ab619b fix(copilot): address coderabbitai nitpicks
- Use `type Props` instead of `interface Props` in ConnectIntegrationTool
- Simplify parseOutput: skip stringify→parse round-trip for objects
- Document why requires_auth=True despite user_id not being used
2026-03-15 23:14:54 +07:00
Zamil Majdy
a8014d1e92 fix(copilot): address sentry — refresh expired OAuth tokens, handle object output in parseOutput
- service.py: use IntegrationCredentialsManager.refresh_if_needed() instead of
  raw IntegrationCredentialsStore so expired GitHub OAuth tokens are refreshed
  before injection; falls back to stale token on refresh failure to avoid
  breaking the turn entirely (lock=False to avoid blocking the turn)
- ConnectIntegrationTool.tsx: parseOutput now handles both string and
  already-parsed object inputs, matching the RunBlock helper pattern
  used elsewhere in the codebase
2026-03-15 23:06:57 +07:00
Zamil Majdy
7de13c7713 fix(copilot): address self-review — GH_TOKEN OAuth preference, unknown provider error, baseline note scope
- service.py: two-pass loop in _get_github_token_for_user() to genuinely
  prefer OAuth2 tokens over API keys; use creds.type discriminator instead
  of isinstance to match codebase style
- connect_integration.py: return ErrorResponse (not SetupRequirementsResponse)
  for unknown providers so the frontend renders a proper error instead of a
  blank broken card; trim "Integration" suffix from agent_name to avoid
  "Connect GitHub Integration" redundancy
- prompting.py: move GitHub CLI / connect_integration guidance from
  _SHARED_TOOL_NOTES (baseline+SDK) to _SDK_TOOL_NOTES (SDK-only) since
  baseline mode has no subprocess, no gh CLI, and no connect_integration tool
- ConnectIntegrationTool.tsx: simplify parseOutput to short-circuit when raw
  is not a string, removing unnecessary JSON.stringify round-trip
2026-03-15 23:00:46 +07:00
Zamil Majdy
9358b525a0 feat(copilot): inject GH_TOKEN and add connect_integration tool for missing GitHub credentials
When the user has connected GitHub, GH_TOKEN is automatically injected
into the Claude Agent SDK subprocess environment so `gh` CLI works
without any manual auth step.

When GitHub is not connected, the copilot can call the new
connect_integration(provider="github") MCP tool, which surfaces the
same credentials setup card used by GitHub blocks — letting the user
connect their account inline without leaving the chat.

- backend: _get_github_token_for_user() fetches the user's GitHub
  credentials (OAuth2 or API key) and injects GH_TOKEN + GITHUB_TOKEN
  into sdk_env before the Claude Agent SDK subprocess starts
- backend: ConnectIntegrationTool MCP tool returns a
  SetupRequirementsResponse for any known provider (github for now)
- backend: prompting.py documents the gh CLI / connect_integration
  flow in _SHARED_TOOL_NOTES so the copilot knows when to call it
- frontend: ConnectIntegrationTool component renders the existing
  SetupRequirementsCard with a tailored retry instruction
- frontend: MessagePartRenderer dispatches tool-connect_integration
  to the new component
2026-03-15 22:55:08 +07:00
44 changed files with 1350 additions and 2375 deletions

View File

@@ -5,14 +5,12 @@ on:
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
merge_group:

View File

@@ -120,6 +120,175 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
integration_test:
runs-on: ubuntu-latest
needs: setup

View File

@@ -1,18 +1,14 @@
name: AutoGPT Platform - Full-stack CI
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
merge_group:
@@ -28,28 +24,42 @@ defaults:
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Enable corepack
run: corepack enable
- name: Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install dependencies to populate cache
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
check-api-types:
name: check API types
runs-on: ubuntu-latest
types:
runs-on: big-boi
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
@@ -57,256 +67,70 @@ jobs:
with:
submodules: recursive
# ------------------------ Backend setup ------------------------
- name: Set up Backend - Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Set up Backend - Install Poetry
working-directory: autogpt_platform/backend
run: |
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Installing Poetry version ${POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
- name: Set up Backend - Set up dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Set up Backend - Install dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Set up Backend - Generate Prisma client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
- name: Set up Frontend - Export OpenAPI schema from Backend
working-directory: autogpt_platform/backend
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
# ------------------------ Frontend setup ------------------------
- name: Set up Frontend - Enable corepack
run: corepack enable
- name: Set up Frontend - Set up Node
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up Frontend - Install dependencies
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
- name: Restore dependencies cache
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up Frontend - Format OpenAPI schema
id: format-schema
run: pnpm prettier --write ./src/app/api/openapi.json
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after exporting the API schema."
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "\nIn the backend directory:"
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
echo "\nIn the frontend directory:"
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
echo "3. Run 'pnpm generate:api'"
echo "4. Run 'pnpm types'"
echo "5. Fix any TypeScript errors that may have been introduced"
echo "6. Commit and push your changes"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Set up Frontend - Generate API client
id: generate-api-client
run: pnpm orval --config ./orval.config.ts
# Continue with type generation & check even if there are schema changes
if: success() || (steps.format-schema.outcome == 'success')
- name: Check for TypeScript errors
- name: Run Typescript checks
run: pnpm types
if: success() || (steps.generate-api-client.outcome == 'success')
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs

View File

@@ -8,7 +8,7 @@ from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
@@ -27,12 +27,6 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
@@ -126,8 +120,6 @@ class SessionDetailResponse(BaseModel):
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
class SessionSummaryResponse(BaseModel):
@@ -215,7 +207,7 @@ async def list_sessions(
}
except Exception:
logger.warning(
"Failed to fetch processing status from Redis; defaulting to empty"
"Failed to fetch processing status from Redis; " "defaulting to empty"
)
return ListSessionsResponse(
@@ -237,7 +229,7 @@ async def list_sessions(
"/sessions",
)
async def create_session(
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -356,7 +348,7 @@ async def update_session_title_route(
)
async def get_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
@@ -397,10 +389,6 @@ async def get_session(
last_message_id=last_message_id,
)
# Sum token usage from session
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
@@ -408,25 +396,6 @@ async def get_session(
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
@router.get(
"/usage",
)
async def get_copilot_usage(
user_id: Annotated[str, Security(auth.get_user_id)],
) -> CoPilotUsageStatus:
"""Get CoPilot usage status for the authenticated user.
Returns current token usage vs limits for daily and weekly windows.
"""
return await get_usage_status(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
@@ -436,7 +405,7 @@ async def get_copilot_usage(
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
@@ -481,7 +450,7 @@ async def cancel_session_task(
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str = Security(auth.get_user_id),
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
@@ -498,7 +467,7 @@ async def stream_chat_post(
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Authenticated user ID.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
@@ -507,7 +476,9 @@ async def stream_chat_post(
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id, "user_id": user_id}
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
@@ -525,18 +496,6 @@ async def stream_chat_post(
},
)
# Pre-turn rate limit check (token-based).
# check_rate_limit short-circuits internally when both limits are 0.
if user_id:
try:
await check_rate_limit(
user_id=user_id,
daily_token_limit=config.daily_token_limit,
weekly_token_limit=config.weekly_token_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
@@ -771,7 +730,7 @@ async def stream_chat_post(
)
async def resume_session_stream(
session_id: str,
user_id: str = Security(auth.get_user_id),
user_id: str | None = Depends(auth.get_user_id),
):
"""
Resume an active stream for a session.

View File

@@ -1,6 +1,5 @@
"""Tests for chat API routes: session title update, file attachment validation, usage, rate limiting, and suggested prompts."""
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -252,156 +251,6 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
assert call_kwargs["where"]["isDeleted"] is False
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
# Ensure the rate-limit branch is entered by setting a non-zero limit.
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
assert "daily" in response.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
resets_at = datetime.now(UTC) + timedelta(days=3)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("weekly", resets_at),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"].lower()
assert "weekly" in detail
assert "resets in" in detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
"""The 429 response detail should include the human-readable reset time."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_stream_internals(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded(
"daily", datetime.now(UTC) + timedelta(hours=2, minutes=30)
),
)
response = client.post(
"/sessions/sess-1/stream",
json={"message": "hello"},
)
assert response.status_code == 429
detail = response.json()["detail"]
assert "2h" in detail
assert "Resets in" in detail
# ─── Usage endpoint ───────────────────────────────────────────────────
def _mock_usage(
mocker: pytest_mock.MockerFixture,
*,
daily_used: int = 500,
weekly_used: int = 2000,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
new_callable=AsyncMock,
return_value=status,
)
def test_usage_returns_daily_and_weekly(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""GET /usage returns daily and weekly usage."""
mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000)
mocker.patch.object(chat_routes.config, "daily_token_limit", 10000)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000)
response = client.get("/usage")
assert response.status_code == 200
data = response.json()
assert data["daily"]["used"] == 500
assert data["weekly"]["used"] == 2000
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=10000,
weekly_token_limit=50000,
)
def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
response = client.get("/usage")
assert response.status_code == 200
mock_get.assert_called_once_with(
user_id=test_user_id,
daily_token_limit=99999,
weekly_token_limit=77777,
)
def test_usage_rejects_unauthenticated_request() -> None:
"""GET /usage should return 401 when no valid JWT is provided."""
unauthenticated_app = fastapi.FastAPI()
unauthenticated_app.include_router(chat_routes.router)
unauthenticated_client = fastapi.testclient.TestClient(unauthenticated_app)
response = unauthenticated_client.get("/usage")
assert response.status_code == 401
# ─── Suggested prompts endpoint ──────────────────────────────────────

View File

@@ -36,15 +36,13 @@ from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
_get_openai_client,
client,
config,
)
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
@@ -91,7 +89,7 @@ async def _compress_session_messages(
result = await compress_context(
messages=messages_dict,
model=config.model,
client=_get_openai_client(),
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
@@ -223,10 +221,6 @@ async def stream_chat_completion_baseline(
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
# Token usage accumulators — populated from streaming chunks
turn_prompt_tokens = 0
turn_completion_tokens = 0
_stream_error = False # Track whether an error occurred during streaming
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
@@ -238,31 +232,16 @@ async def stream_chat_completion_baseline(
model=config.model,
messages=openai_messages,
stream=True,
stream_options={"include_usage": True},
)
if tools:
create_kwargs["tools"] = tools
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
# Capture token usage from the streaming chunk.
# OpenRouter normalises all providers into OpenAI format
# where prompt_tokens already includes cached tokens
# (unlike Anthropic's native API). Use += to sum all
# tool-call rounds since each API call is independent.
# NOTE: stream_options={"include_usage": True} is not
# universally supported — some providers (Mistral, Llama
# via OpenRouter) always return chunk.usage=None. When
# that happens, tokens stay 0 and the tiktoken fallback
# below activates. Fail-open: one round is estimated.
if chunk.usage:
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
@@ -415,7 +394,6 @@ async def stream_chat_completion_baseline(
)
except Exception as e:
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
@@ -433,49 +411,6 @@ async def stream_chat_completion_baseline(
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Fallback: estimate tokens via tiktoken when the provider does
# not honour stream_options={"include_usage": True}.
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
# NOTE: This estimates one round's prompt tokens. Multi-round tool-calling
# turns consume prompt tokens on each API call, so the total is underestimated.
# Skip fallback when an error occurred and no output was produced —
# charging rate-limit tokens for completely failed requests is unfair.
if (
turn_prompt_tokens == 0
and turn_completion_tokens == 0
and not (_stream_error and not assistant_text)
):
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
)
turn_completion_tokens = estimate_token_count_str(
assistant_text, model=config.model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
"prompt=%d, completion=%d",
turn_prompt_tokens,
turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
# cannot break out cache_read/cache_creation weights. Users on the
# baseline path may be slightly over-counted vs the SDK path.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
log_prefix="[Baseline]",
)
# Persist assistant response
if assistant_text:
session.messages.append(
@@ -486,16 +421,4 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=turn_prompt_tokens + turn_completion_tokens,
)
yield StreamFinish()

View File

@@ -70,27 +70,6 @@ class ChatConfig(BaseSettings):
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Rate limiting — token-based limits per day and per week.
# Per-turn token cost varies with context size: ~10-15K for early turns,
# ~30-50K mid-session, up to ~100K pre-compaction. Average across a
# session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
)
weekly_token_limit: int = Field(
default=12_500_000,
description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,

View File

@@ -0,0 +1,162 @@
"""Integration credential lookup with per-process TTL cache.
Provides token retrieval for connected integrations so that copilot tools
(e.g. bash_exec) can inject auth tokens into the execution environment without
hitting the database on every command.
Cache semantics (handled automatically by TTLCache):
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
for users who have credentials and are running many bash commands.
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
on every E2B command for users who haven't connected an account yet, while
still picking up a newly-connected account within one minute.
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
least-recently-used entry when the limit is reached.
Multi-worker note: both caches are in-process only. Each worker/replica
maintains its own independent cache, so a credential fetch may be duplicated
across processes. This is acceptable for the current goal (reduce DB hits per
session per-process), but if cache efficiency across replicas becomes important
a shared cache (e.g. Redis) should be used instead.
"""
import logging
from typing import cast
from cachetools import TTLCache
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
register_creds_changed_hook,
)
logger = logging.getLogger(__name__)
# Maps provider slug → env var names to inject when the provider is connected.
# Add new providers here when adding integration support.
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
# must be updated when adding a new provider.
PROVIDER_ENV_VARS: dict[str, list[str]] = {
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
}
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
_CACHE_MAX_SIZE = 10_000
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
# because all callers (get_provider_token, invalidate_user_provider_cache) run
# exclusively on the asyncio event loop. There are no await points between a
# cache read and its corresponding write within any function, so no concurrent
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
# this path, a threading.RLock should be wrapped around these caches.
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
)
# Separate cache for "no credentials" results with a shorter TTL.
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
# entries. This avoids a lazy import inside creds_manager and eliminates the
# circular-import risk.
register_creds_changed_hook(invalidate_user_provider_cache)
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
# on every cache-miss call to get_provider_token().
_manager = IntegrationCredentialsManager()
async def get_provider_token(user_id: str, provider: str) -> str | None:
"""Return the user's access token for *provider*, or ``None`` if not connected.
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
command for users who haven't connected yet, while still picking up a
newly-connected account within one minute.
"""
cache_key = (user_id, provider)
if cache_key in _null_cache:
return None
if cached := _token_cache.get(cache_key):
return cached
manager = _manager
try:
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
except Exception:
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
return None
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
# full git access, while a public-data-only token lacks push/pull permission.
# lock=False — background injection; not worth a distributed lock acquisition.
oauth2_creds = sorted(
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
fresh = await manager.refresh_if_needed(
user_id, cast(OAuth2Credentials, creds), lock=False
)
token = fresh.access_token.get_secret_value()
except Exception:
logger.warning(
"Failed to refresh %s OAuth token for user %s; "
"falling back to potentially stale token",
provider,
user_id,
)
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
_token_cache[cache_key] = token
return token
# Pass 2: fall back to API key (no expiry, no refresh needed).
for creds in creds_list:
if creds.type == "api_key":
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
"""Return env vars for all providers the user has connected.
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
Only providers with a stored credential contribute entries.
"""
env: dict[str, str] = {}
for provider, var_names in PROVIDER_ENV_VARS.items():
token = await get_provider_token(user_id, provider)
if token:
for var in var_names:
env[var] = token
return env

View File

@@ -0,0 +1,193 @@
"""Tests for integration_creds — TTL cache and token lookup paths."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_null_cache,
_token_cache,
get_integration_env_vars,
get_provider_token,
invalidate_user_provider_cache,
)
from backend.data.model import APIKeyCredentials, OAuth2Credentials
_USER = "user-integration-creds-test"
_PROVIDER = "github"
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
return APIKeyCredentials(
id="creds-api-key",
provider=_PROVIDER,
api_key=SecretStr(key),
title="Test API Key",
expires_at=None,
)
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
return OAuth2Credentials(
id="creds-oauth2",
provider=_PROVIDER,
title="Test OAuth",
access_token=SecretStr(token),
refresh_token=SecretStr("test-refresh"),
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
)
@pytest.fixture(autouse=True)
def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
class TestInvalidateUserProviderCache:
def test_removes_token_entry(self):
key = (_USER, _PROVIDER)
_token_cache[key] = "tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _token_cache
def test_removes_null_entry(self):
key = (_USER, _PROVIDER)
_null_cache[key] = True
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _null_cache
def test_noop_when_key_not_cached(self):
# Should not raise even when there is no cache entry.
invalidate_user_provider_cache("no-such-user", _PROVIDER)
def test_only_removes_targeted_key(self):
other_key = ("other-user", _PROVIDER)
_token_cache[other_key] = "other-tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_cached_token_without_db_hit(self):
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "cached-tok"
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_none_for_null_cached_provider(self):
_null_cache[(_USER, _PROVIDER)] = True
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_api_key_creds_returned_and_cached(self):
api_creds = _make_api_key_creds("my-api-key")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "my-api-key"
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_preferred_over_api_key(self):
oauth_creds = _make_oauth2_creds("oauth-tok")
api_creds = _make_api_key_creds("api-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
return_value=[api_creds, oauth_creds]
)
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "stale-oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
assert _null_cache.get((_USER, _PROVIDER)) is True
@pytest.mark.asyncio(loop_scope="session")
async def test_db_exception_returns_none_without_caching(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
side_effect=RuntimeError("db down")
)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
# DB errors are not cached — next call will retry
assert (_USER, _PROVIDER) not in _token_cache
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
"""Verify the TTL constants are set correctly for each cache."""
assert _null_cache.ttl == _NULL_CACHE_TTL
assert _token_cache.ttl == _TOKEN_CACHE_TTL
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):
_token_cache[(_USER, "github")] = "gh-tok"
result = await get_integration_env_vars(_USER)
for var in PROVIDER_ENV_VARS["github"]:
assert result[var] == "gh-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_dict_when_no_credentials(self):
_null_cache[(_USER, "github")] = True
result = await get_integration_env_vars(_USER)
assert result == {}

View File

@@ -73,9 +73,6 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Cache breakdown (Anthropic-specific; zero for non-Anthropic models)
cache_read_tokens: int = 0
cache_creation_tokens: int = 0
class ChatSessionInfo(BaseModel):
@@ -101,10 +98,7 @@ class ChatSessionInfo(BaseModel):
prisma_session.successfulAgentSchedules, default={}
)
# Calculate usage from token counts.
# NOTE: Per-turn cache_read_tokens / cache_creation_tokens breakdown
# is lost after persistence — the DB only stores aggregate prompt and
# completion totals. This is a known limitation.
# Calculate usage from token counts
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(

View File

@@ -95,6 +95,25 @@ Example — committing an image file to GitHub:
All tasks must run in the foreground.
"""
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.
- For operations that need broader access (e.g. private org repos, GitHub
Actions), pass the required scopes: e.g.
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
@@ -105,6 +124,7 @@ def _build_storage_supplement(
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
extra_notes: str = "",
) -> str:
"""Build storage/filesystem supplement for a specific environment.
@@ -119,6 +139,7 @@ def _build_storage_supplement(
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
extra_notes: Environment-specific notes appended after shared notes
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
@@ -152,12 +173,16 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
{_SHARED_TOOL_NOTES}{extra_notes}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
"""Local ephemeral storage (files lost between turns).
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
like gh will not work — no integration env-var notes are included.
"""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
@@ -175,7 +200,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
"""Cloud persistent sandbox (files survive across turns in session).
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
injected per command in bash_exec — include the CLI guidance notes.
"""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
@@ -190,6 +219,7 @@ def _get_cloud_sandbox_supplement() -> str:
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
extra_notes=_E2B_TOOL_NOTES,
)

View File

@@ -1,266 +0,0 @@
"""CoPilot rate limiting based on token usage.
Uses Redis fixed-window counters to track per-user token consumption
with configurable daily and weekly limits. Daily windows reset at
midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00
UTC). Fails open when Redis is unavailable to avoid blocking users.
"""
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Redis key prefixes
_USAGE_KEY_PREFIX = "copilot:usage"
class UsageWindow(BaseModel):
"""Usage within a single time window."""
used: int
limit: int = Field(
description="Maximum tokens allowed in this window. 0 means unlimited."
)
resets_at: datetime
class CoPilotUsageStatus(BaseModel):
"""Current usage status for a user across all windows."""
daily: UsageWindow
weekly: UsageWindow
class RateLimitExceeded(Exception):
"""Raised when a user exceeds their CoPilot usage limit."""
def __init__(self, window: str, resets_at: datetime):
self.window = window
self.resets_at = resets_at
delta = resets_at - datetime.now(UTC)
total_secs = delta.total_seconds()
if total_secs <= 0:
time_str = "now"
else:
hours = int(total_secs // 3600)
minutes = int((total_secs % 3600) // 60)
time_str = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"
super().__init__(
f"You've reached your {window} usage limit. Resets in {time_str}."
)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
Args:
user_id: The user's ID.
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
Returns:
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
daily_used = 0
weekly_used = 0
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for usage status, returning zeros")
return CoPilotUsageStatus(
daily=UsageWindow(
used=daily_used,
limit=daily_token_limit,
resets_at=_daily_reset_time(now=now),
),
weekly=UsageWindow(
used=weekly_used,
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
)
async def check_rate_limit(
user_id: str,
daily_token_limit: int,
weekly_token_limit: int,
) -> None:
"""Check if user is within rate limits. Raises RateLimitExceeded if not.
This is a pre-turn soft check. The authoritative usage counter is updated
by ``record_token_usage()`` after the turn completes. Under concurrency,
two parallel turns may both pass this check against the same snapshot.
This is acceptable because token-based limits are approximate by nature
(the exact token count is unknown until after generation).
Fails open: if Redis is unavailable, allows the request.
"""
# Short-circuit: when both limits are 0 (unlimited) skip the Redis
# round-trip entirely.
if daily_token_limit <= 0 and weekly_token_limit <= 0:
return
now = datetime.now(UTC)
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
except (RedisError, ConnectionError, OSError):
logger.warning("Redis unavailable for rate limit check, allowing request")
return
# Worst-case overshoot: N concurrent requests × ~15K tokens each.
if daily_token_limit > 0 and daily_used >= daily_token_limit:
raise RateLimitExceeded("daily", _daily_reset_time(now=now))
if weekly_token_limit > 0 and weekly_used >= weekly_token_limit:
raise RateLimitExceeded("weekly", _weekly_reset_time(now=now))
async def record_token_usage(
user_id: str,
prompt_tokens: int,
completion_tokens: int,
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
) -> None:
"""Record token usage for a user across all windows.
Uses cost-weighted counting so cached tokens don't unfairly penalise
multi-turn conversations. Anthropic's pricing:
- uncached input: 100%
- cache creation: 25%
- cache read: 10%
- output: 100%
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
weighted_input = (
prompt_tokens
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
if total <= 0:
return
raw_total = (
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
)
now = datetime.now(UTC)
try:
redis = await get_redis_async()
# transaction=False: these are independent INCRBY+EXPIRE pairs on
# separate keys — no cross-key atomicity needed. Skipping
# MULTI/EXEC avoids the overhead. If the connection drops between
# INCRBY and EXPIRE the key survives until the next date-based key
# rotation (daily/weekly), so the memory-leak risk is negligible.
pipe = redis.pipeline(transaction=False)
# Daily counter (expires at next midnight UTC)
d_key = _daily_key(user_id, now=now)
pipe.incrby(d_key, total)
seconds_until_daily_reset = int(
(_daily_reset_time(now=now) - now).total_seconds()
)
pipe.expire(d_key, max(seconds_until_daily_reset, 1))
# Weekly counter (expires end of week)
w_key = _weekly_key(user_id, now=now)
pipe.incrby(w_key, total)
seconds_until_weekly_reset = int(
(_weekly_reset_time(now=now) - now).total_seconds()
)
pipe.expire(w_key, max(seconds_until_weekly_reset, 1))
await pipe.execute()
except (RedisError, ConnectionError, OSError):
logger.warning(
"Redis unavailable for recording token usage (tokens=%d)",
total,
)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _daily_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
return f"{_USAGE_KEY_PREFIX}:daily:{user_id}:{now.strftime('%Y-%m-%d')}"
def _weekly_key(user_id: str, now: datetime | None = None) -> str:
if now is None:
now = datetime.now(UTC)
year, week, _ = now.isocalendar()
return f"{_USAGE_KEY_PREFIX}:weekly:{user_id}:{year}-W{week:02d}"
def _daily_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current daily window resets (next midnight UTC)."""
if now is None:
now = datetime.now(UTC)
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
def _weekly_reset_time(now: datetime | None = None) -> datetime:
"""Calculate when the current weekly window resets (next Monday 00:00 UTC)."""
if now is None:
now = datetime.now(UTC)
days_until_monday = (7 - now.weekday()) % 7 or 7
return now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(
days=days_until_monday
)

View File

@@ -1,334 +0,0 @@
"""Unit tests for CoPilot rate limiting."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from redis.exceptions import RedisError
from .rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
check_rate_limit,
get_usage_status,
record_token_usage,
)
_USER = "test-user-rl"
# ---------------------------------------------------------------------------
# RateLimitExceeded
# ---------------------------------------------------------------------------
class TestRateLimitExceeded:
def test_message_contains_window_name(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1))
assert "daily" in str(exc)
def test_message_contains_reset_time(self):
exc = RateLimitExceeded(
"weekly", datetime.now(UTC) + timedelta(hours=2, minutes=30)
)
msg = str(exc)
# Allow for slight timing drift (29m or 30m)
assert "2h " in msg
assert "Resets in" in msg
def test_message_minutes_only_when_under_one_hour(self):
exc = RateLimitExceeded("daily", datetime.now(UTC) + timedelta(minutes=15))
msg = str(exc)
assert "Resets in" in msg
# Should not have "0h"
assert "0h" not in msg
def test_message_says_now_when_resets_at_is_in_the_past(self):
"""Negative delta (clock skew / stale TTL) should say 'now', not '-1h -30m'."""
exc = RateLimitExceeded("daily", datetime.now(UTC) - timedelta(minutes=5))
assert "Resets in now" in str(exc)
# ---------------------------------------------------------------------------
# get_usage_status
# ---------------------------------------------------------------------------
class TestGetUsageStatus:
@pytest.mark.asyncio
async def test_returns_redis_values(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", "2000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert isinstance(status, CoPilotUsageStatus)
assert status.daily.used == 500
assert status.daily.limit == 10000
assert status.weekly.used == 2000
assert status.weekly.limit == 50000
@pytest.mark.asyncio
async def test_returns_zeros_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_partial_none_daily_counter(self):
"""Daily counter is None (new day), weekly has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[None, "3000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 0
assert status.weekly.used == 3000
@pytest.mark.asyncio
async def test_partial_none_weekly_counter(self):
"""Weekly counter is None (start of week), daily has usage."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["500", None])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert status.daily.used == 500
assert status.weekly.used == 0
@pytest.mark.asyncio
async def test_resets_at_daily_is_next_midnight_utc(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["0", "0"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
status = await get_usage_status(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
now = datetime.now(UTC)
# Daily reset should be within 24h
assert status.daily.resets_at > now
assert status.daily.resets_at <= now + timedelta(hours=24, seconds=5)
# ---------------------------------------------------------------------------
# check_rate_limit
# ---------------------------------------------------------------------------
class TestCheckRateLimit:
@pytest.mark.asyncio
async def test_allows_when_under_limit(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_raises_when_daily_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["10000", "200"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_raises_when_weekly_limit_exceeded(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100", "50000"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_allows_when_redis_unavailable(self):
"""Fail-open: allow requests when Redis is down."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await check_rate_limit(
_USER, daily_token_limit=10000, weekly_token_limit=50000
)
@pytest.mark.asyncio
async def test_skips_check_when_limit_is_zero(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["999999", "999999"])
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — limits of 0 mean unlimited
await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0)
# ---------------------------------------------------------------------------
# record_token_usage
# ---------------------------------------------------------------------------
class TestRecordTokenUsage:
@staticmethod
def _make_pipeline_mock() -> MagicMock:
"""Create a pipeline mock with sync methods and async execute."""
pipe = MagicMock()
pipe.execute = AsyncMock(return_value=[])
return pipe
@pytest.mark.asyncio
async def test_increments_redis_counters(self):
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# Should call incrby twice (daily + weekly) with total=150
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 150 # daily
assert incrby_calls[1].args[1] == 150 # weekly
@pytest.mark.asyncio
async def test_skips_when_zero_tokens(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0)
# Should not call pipeline at all
mock_redis.pipeline.assert_not_called()
@pytest.mark.asyncio
async def test_sets_expire_on_both_keys(self):
"""Pipeline should call expire for both daily and weekly keys."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
expire_calls = mock_pipe.expire.call_args_list
assert len(expire_calls) == 2
# Daily key TTL should be positive (seconds until next midnight)
daily_ttl = expire_calls[0].args[1]
assert daily_ttl >= 1
# Weekly key TTL should be positive (seconds until next Monday)
weekly_ttl = expire_calls[1].args[1]
assert weekly_ttl >= 1
@pytest.mark.asyncio
async def test_handles_redis_failure_gracefully(self):
"""Should not raise when Redis is unavailable."""
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=ConnectionError("Redis down"),
):
# Should not raise
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
@pytest.mark.asyncio
async def test_cost_weighted_counting(self):
"""Cached tokens should be weighted: cache_read=10%, cache_create=25%."""
mock_pipe = self._make_pipeline_mock()
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
await record_token_usage(
_USER,
prompt_tokens=100, # uncached → 100
completion_tokens=50, # output → 50
cache_read_tokens=10000, # 10% → 1000
cache_creation_tokens=400, # 25% → 100
)
# Expected weighted total: 100 + 1000 + 100 + 50 = 1250
incrby_calls = mock_pipe.incrby.call_args_list
assert len(incrby_calls) == 2
assert incrby_calls[0].args[1] == 1250 # daily
assert incrby_calls[1].args[1] == 1250 # weekly
@pytest.mark.asyncio
async def test_handles_redis_error_during_pipeline_execute(self):
"""Should not raise when pipeline.execute() fails with RedisError."""
mock_pipe = self._make_pipeline_mock()
mock_pipe.execute = AsyncMock(side_effect=RedisError("Pipeline failed"))
mock_redis = AsyncMock()
mock_redis.pipeline = lambda **_kw: mock_pipe
with patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
):
# Should not raise — fail-open
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)

View File

@@ -186,43 +186,12 @@ class StreamToolOutputAvailable(StreamBaseResponse):
class StreamUsage(StreamBaseResponse):
"""Token usage statistics.
Emitted as an SSE comment so the Vercel AI SDK parser ignores it
(it uses z.strictObject() and rejects unknown event types).
Usage data is recorded server-side (session DB + Redis counters).
"""
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
prompt_tokens: int = Field(
...,
serialization_alias="promptTokens",
description="Number of uncached prompt tokens",
)
completion_tokens: int = Field(
...,
serialization_alias="completionTokens",
description="Number of completion tokens",
)
total_tokens: int = Field(
...,
serialization_alias="totalTokens",
description="Total number of tokens (raw, not weighted)",
)
cache_read_tokens: int = Field(
default=0,
serialization_alias="cacheReadTokens",
description="Prompt tokens served from cache (10% cost)",
)
cache_creation_tokens: int = Field(
default=0,
serialization_alias="cacheCreationTokens",
description="Prompt tokens written to cache (25% cost)",
)
def to_sse(self) -> str:
"""Emit as SSE comment so the AI SDK parser ignores it."""
return f": usage {self.model_dump_json(exclude_none=True, by_alias=True)}\n\n"
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
class StreamError(StreamBaseResponse):

View File

@@ -55,14 +55,12 @@ from ..response_model import (
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
StreamUsage,
)
from ..service import (
_build_system_prompt,
_generate_session_title,
_is_langfuse_configured,
)
from ..token_tracking import persist_and_record_usage
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
@@ -738,13 +736,6 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
# Make sure there is no more code between the lock acquisition and try-block.
# Token usage accumulators — populated from ResultMessage at end of turn
turn_prompt_tokens = 0 # uncached input tokens only
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
@@ -778,7 +769,7 @@ async def stream_chat_completion_sdk(
)
return None
try:
return await get_or_create_sandbox(
sandbox = await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
@@ -792,7 +783,9 @@ async def stream_chat_completion_sdk(
e2b_err,
exc_info=True,
)
return None
return None
return sandbox
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
@@ -1121,7 +1114,7 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details and capture token usage
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
@@ -1140,33 +1133,6 @@ async def stream_chat_completion_sdk(
sdk_msg.result or "(no error message provided)",
)
# Capture token usage from ResultMessage.
# Anthropic reports cached tokens separately:
# input_tokens = uncached only
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
turn_prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
turn_cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
)
turn_cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
)
turn_completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, cache_create=%d, output=%d",
log_prefix,
turn_prompt_tokens,
turn_cache_read_tokens,
turn_cache_creation_tokens,
turn_completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
turn_cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# When compaction ends, sync TranscriptBuilder with the
# CLI's active context so they stay identical.
@@ -1383,26 +1349,6 @@ async def stream_chat_completion_sdk(
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Emit token usage to the client (must be in try to reach SSE stream).
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
# total_tokens = prompt (uncached input) + completion (output).
# Cache tokens are tracked separately and excluded from total
# so that the semantics match the baseline path (OpenRouter)
# which folds cache into prompt_tokens. Keeping total_tokens
# = prompt + completion everywhere makes cross-path comparisons
# and session-level aggregation consistent.
total_tokens = turn_prompt_tokens + turn_completion_tokens
yield StreamUsage(
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
@@ -1467,20 +1413,6 @@ async def stream_chat_completion_sdk(
except Exception:
logger.warning("OTEL context teardown failed", exc_info=True)
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
)
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).

View File

@@ -28,24 +28,10 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
settings = Settings()
_client: LangfuseAsyncOpenAI | None = None
_langfuse = None
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
def _get_openai_client() -> LangfuseAsyncOpenAI:
global _client
if _client is None:
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
return _client
def _get_langfuse():
global _langfuse
if _langfuse is None:
_langfuse = get_client()
return _langfuse
langfuse = get_client()
# Default system prompt used when Langfuse is not configured
# Provides minimal baseline tone and personality - all workflow, tools, and
@@ -98,7 +84,7 @@ async def _get_system_prompt_template(context: str) -> str:
else "latest"
)
prompt = await asyncio.to_thread(
_get_langfuse().get_prompt,
langfuse.get_prompt,
config.langfuse_prompt_name,
label=label,
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
@@ -172,7 +158,7 @@ async def _generate_session_title(
"environment": settings.config.app_env.value,
}
response = await _get_openai_client().chat.completions.create(
response = await client.chat.completions.create(
model=config.title_model,
messages=[
{

View File

@@ -1,93 +0,0 @@
"""Shared token-usage persistence and rate-limit recording.
Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
This module extracts that common logic so both paths stay in sync.
"""
import logging
from .model import ChatSession, Usage
from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
async def persist_and_record_usage(
*,
session: ChatSession | None,
user_id: str | None,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
log_prefix: str = "",
cost_usd: float | str | None = None,
) -> int:
"""Persist token usage to session and record for rate limiting.
Args:
session: The chat session to append usage to (may be None on error).
user_id: User ID for rate-limit counters (skipped if None).
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (Anthropic only).
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
Returns:
The computed total_tokens (prompt + completion; cache excluded).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
if prompt_tokens <= 0 and completion_tokens <= 0:
return 0
# total_tokens = prompt + completion. Cache tokens are tracked
# separately and excluded from total so both baseline and SDK
# paths share the same semantics.
total_tokens = prompt_tokens + completion_tokens
if session is not None:
session.usage.append(
Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
)
if cache_read_tokens or cache_creation_tokens:
logger.info(
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
)
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
f"completion={completion_tokens}, total={total_tokens}"
)
if user_id:
try:
await record_token_usage(
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
return total_tokens

View File

@@ -1,281 +0,0 @@
"""Unit tests for token_tracking.persist_and_record_usage.
Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from .model import ChatSession, Usage
from .token_tracking import persist_and_record_usage
def _make_session() -> ChatSession:
"""Return a minimal in-memory ChatSession for testing."""
return ChatSession(
session_id="sess-test",
user_id="user-test",
title=None,
messages=[],
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
# ---------------------------------------------------------------------------
# Return value / total_tokens semantics
# ---------------------------------------------------------------------------
class TestTotalTokens:
@pytest.mark.asyncio
async def test_returns_prompt_plus_completion(self):
"""total_tokens = prompt + completion (cache excluded from total)."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=300,
completion_tokens=200,
)
assert total == 500
@pytest.mark.asyncio
async def test_returns_zero_when_no_tokens(self):
"""Returns 0 early when both prompt and completion are zero."""
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=0,
completion_tokens=0,
)
assert total == 0
@pytest.mark.asyncio
async def test_cache_tokens_excluded_from_total(self):
"""Cache tokens are stored separately and not added to total_tokens."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=5000,
cache_creation_tokens=200,
)
# total = prompt + completion only (5000 + 200 cache excluded)
assert total == 150
@pytest.mark.asyncio
async def test_baseline_path_no_cache(self):
"""Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id="u1",
prompt_tokens=1000,
completion_tokens=400,
log_prefix="[Baseline]",
)
assert total == 1400
@pytest.mark.asyncio
async def test_sdk_path_with_cache(self):
"""SDK (Anthropic) path passes cache tokens; total still = prompt + completion."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id="u2",
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=8000,
cache_creation_tokens=400,
log_prefix="[SDK]",
cost_usd=0.0015,
)
assert total == 300
# ---------------------------------------------------------------------------
# Session persistence
# ---------------------------------------------------------------------------
class TestSessionPersistence:
@pytest.mark.asyncio
async def test_appends_usage_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
assert len(session.usage) == 1
usage: Usage = session.usage[0]
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
assert usage.cache_read_tokens == 0
assert usage.cache_creation_tokens == 0
@pytest.mark.asyncio
async def test_appends_cache_breakdown_to_session(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=200,
completion_tokens=80,
cache_read_tokens=3000,
cache_creation_tokens=500,
)
usage: Usage = session.usage[0]
assert usage.cache_read_tokens == 3000
assert usage.cache_creation_tokens == 500
@pytest.mark.asyncio
async def test_multiple_turns_append_multiple_records(self):
session = _make_session()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
await persist_and_record_usage(
session=session, user_id=None, prompt_tokens=100, completion_tokens=50
)
await persist_and_record_usage(
session=session, user_id=None, prompt_tokens=200, completion_tokens=70
)
assert len(session.usage) == 2
@pytest.mark.asyncio
async def test_none_session_does_not_raise(self):
"""When session is None (e.g. error path), no exception should be raised."""
with patch(
"backend.copilot.token_tracking.record_token_usage",
new_callable=AsyncMock,
):
total = await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
assert total == 150
@pytest.mark.asyncio
async def test_no_append_when_zero_tokens(self):
"""When tokens are zero, function returns early — session unchanged."""
session = _make_session()
total = await persist_and_record_usage(
session=session,
user_id=None,
prompt_tokens=0,
completion_tokens=0,
)
assert total == 0
assert len(session.usage) == 0
# ---------------------------------------------------------------------------
# Rate-limit recording
# ---------------------------------------------------------------------------
class TestRateLimitRecording:
@pytest.mark.asyncio
async def test_calls_record_token_usage_when_user_id_present(self):
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
)
mock_record.assert_awaited_once_with(
user_id="user-abc",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
)
@pytest.mark.asyncio
async def test_skips_record_when_user_id_is_none(self):
"""Anonymous sessions should not create Redis keys."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id=None,
prompt_tokens=100,
completion_tokens=50,
)
mock_record.assert_not_awaited()
@pytest.mark.asyncio
async def test_record_failure_does_not_raise(self):
"""A Redis error in record_token_usage should be swallowed (fail-open)."""
mock_record = AsyncMock(side_effect=ConnectionError("Redis down"))
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
# Should not raise
total = await persist_and_record_usage(
session=None,
user_id="user-xyz",
prompt_tokens=100,
completion_tokens=50,
)
assert total == 150
@pytest.mark.asyncio
async def test_skips_record_when_zero_tokens(self):
"""Returns 0 before calling record_token_usage when tokens are zero."""
mock_record = AsyncMock()
with patch(
"backend.copilot.token_tracking.record_token_usage",
new=mock_record,
):
await persist_and_record_usage(
session=None,
user_id="user-abc",
prompt_tokens=0,
completion_tokens=0,
)
mock_record.assert_not_awaited()

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
"connect_integration": ConnectIntegrationTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),

View File

@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.integration_creds import get_integration_env_vars
from backend.copilot.model import ChatSession
from .base import BaseTool
@@ -96,7 +97,9 @@ class BashExecTool(BaseTool):
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
return await self._execute_on_e2b(
sandbox, command, timeout, session_id, user_id
)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
@@ -133,14 +136,27 @@ class BashExecTool(BaseTool):
command: str,
timeout: int,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
"""Execute *command* on the E2B sandbox via commands.run().
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
for any user with connected accounts. E2B has full internet access, so
CLI tools like ``gh`` work without manual authentication.
"""
envs: dict[str, str] = {
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
}
if user_id is not None:
integration_env = await get_integration_env_vars(user_id)
envs.update(integration_env)
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
envs=envs,
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",

View File

@@ -0,0 +1,78 @@
"""Tests for BashExecTool — E2B path with token injection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from ._test_data import make_session
from .bash_exec import BashExecTool
from .models import BashExecResponse
_USER = "user-bash-exec-test"
def _make_tool() -> BashExecTool:
return BashExecTool()
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
result = MagicMock()
result.exit_code = exit_code
result.stdout = stdout
result.stderr = stderr
sandbox = MagicMock()
sandbox.commands.run = AsyncMock(return_value=result)
return sandbox
class TestBashExecE2BTokenInjection:
@pytest.mark.asyncio(loop_scope="session")
async def test_token_injected_when_user_id_set(self):
"""When user_id is provided, integration env vars are merged into sandbox envs."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=_USER,
)
mock_get_env.assert_awaited_once_with(_USER)
call_kwargs = sandbox.commands.run.call_args[1]
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
assert isinstance(result, BashExecResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_no_token_injection_when_user_id_is_none(self):
"""When user_id is None, get_integration_env_vars must NOT be called."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=None,
)
mock_get_env.assert_not_called()
call_kwargs = sandbox.commands.run.call_args[1]
assert "GH_TOKEN" not in call_kwargs["envs"]
assert isinstance(result, BashExecResponse)

View File

@@ -0,0 +1,215 @@
"""Tool for prompting the user to connect a required integration.
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
"authentication required"), it calls this tool to surface the credentials
setup card in the chat — the same UI that appears when a GitHub block runs
without configured credentials.
"""
import functools
from typing import Any, TypedDict
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import (
ErrorResponse,
ResponseType,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .base import BaseTool
class _ProviderInfo(TypedDict):
name: str
types: list[str]
# Default OAuth scopes requested when the agent doesn't specify any.
scopes: list[str]
class _CredentialEntry(TypedDict):
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
id: str
title: str
provider: str
provider_name: str
type: str
types: list[str]
scopes: list[str]
@functools.lru_cache(maxsize=1)
def _is_github_oauth_configured() -> bool:
"""Return True if GitHub OAuth env vars are set.
Evaluated lazily (not at import time) to avoid triggering Secrets() during
module import, which can fail in environments where secrets are not loaded.
"""
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
return GITHUB_OAUTH_IS_CONFIGURED
# Registry of known providers: name + supported credential types for the UI.
# When adding a new provider, also add its env var names to
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
def _get_provider_info() -> dict[str, _ProviderInfo]:
"""Build the provider registry, evaluating OAuth config lazily."""
return {
"github": {
"name": "GitHub",
"types": (
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
),
# Default: repo scope covers clone/push/pull for public and private repos.
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
"scopes": ["repo"],
},
}
class ConnectIntegrationTool(BaseTool):
"""Surface the credentials setup UI when an integration is not connected."""
@property
def name(self) -> str:
return "connect_integration"
@property
def description(self) -> str:
return (
"Prompt the user to connect a required integration (e.g. GitHub). "
"Call this when an external CLI or API call fails because the user "
"has not connected the relevant account. "
"The tool surfaces a credentials setup card in the chat so the user "
"can authenticate without leaving the page. "
"After the user connects the account, retry the operation. "
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
"automatically injected per-command in bash_exec — no manual export needed. "
"In local bubblewrap mode network is isolated so GitHub CLI commands "
"will still fail after connecting; inform the user of this limitation."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": (
"Integration provider slug, e.g. 'github'. "
"Must be one of the supported providers."
),
"enum": list(_get_provider_info().keys()),
},
"reason": {
"type": "string",
"description": (
"Brief explanation of why the integration is needed, "
"shown to the user in the setup card."
),
"maxLength": 500,
},
"scopes": {
"type": "array",
"items": {"type": "string"},
"description": (
"OAuth scopes to request. Omit to use the provider default. "
"Add extra scopes when you need more access — e.g. for GitHub: "
"'repo' (clone/push/pull), 'read:org' (org membership), "
"'workflow' (GitHub Actions). "
"Requesting only the scopes you actually need is best practice."
),
},
},
"required": ["provider"],
}
@property
def requires_auth(self) -> bool:
# Require auth so only authenticated users can trigger the setup card.
# The card itself is user-agnostic (no per-user data needed), so
# user_id is intentionally unused in _execute.
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
provider_info = _get_provider_info()
info = provider_info.get(provider)
if not info:
supported = ", ".join(f"'{p}'" for p in provider_info)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
)
provider_name: str = info["name"]
supported_types: list[str] = info["types"]
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = info["scopes"]
seen: set[str] = set()
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
f"To continue, please connect your {provider_name} account.",
]
if reason:
message_parts.append(reason)
credential_entry: _CredentialEntry = {
"id": field_key,
"title": f"{provider_name} Credentials",
"provider": provider,
"provider_name": provider_name,
"type": supported_types[0],
"types": supported_types,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
return SetupRequirementsResponse(
type=ResponseType.SETUP_REQUIREMENTS,
message=" ".join(message_parts),
session_id=session_id,
setup_info=SetupInfo(
agent_id=f"connect_{provider}",
agent_name=provider_name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials,
ready_to_run=False,
),
requirements={
"credentials": [missing_credentials[field_key]],
"inputs": [],
"execution_modes": [],
},
),
)

View File

@@ -0,0 +1,135 @@
"""Tests for ConnectIntegrationTool."""
import pytest
from ._test_data import make_session
from .connect_integration import ConnectIntegrationTool
from .models import ErrorResponse, SetupRequirementsResponse
_TEST_USER_ID = "test-user-connect-integration"
class TestConnectIntegrationTool:
def _make_tool(self) -> ConnectIntegrationTool:
return ConnectIntegrationTool()
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
assert "nonexistent" in result.message
assert "github" in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_provider_returns_setup_response(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_name == "GitHub"
assert result.setup_info.agent_id == "connect_github"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_has_missing_credentials_in_readiness(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
readiness = result.setup_info.user_readiness
assert readiness.has_all_credentials is False
assert readiness.ready_to_run is False
assert "github_credentials" in readiness.missing_credentials
@pytest.mark.asyncio(loop_scope="session")
async def test_github_requirements_include_credential_entry(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
creds = result.setup_info.requirements["credentials"]
assert len(creds) == 1
assert creds[0]["provider"] == "github"
assert creds[0]["id"] == "github_credentials"
@pytest.mark.asyncio(loop_scope="session")
async def test_reason_appears_in_message(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
reason = "Needed to create a pull request."
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
)
assert isinstance(result, SetupRequirementsResponse)
assert reason in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_session_id_propagated(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.session_id == session.session_id
@pytest.mark.asyncio(loop_scope="session")
async def test_provider_case_insensitive(self):
"""Provider slug is normalised to lowercase before lookup."""
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="GitHub"
)
assert isinstance(result, SetupRequirementsResponse)
def test_tool_name(self):
assert ConnectIntegrationTool().name == "connect_integration"
def test_requires_auth(self):
assert ConnectIntegrationTool().requires_auth is True
@pytest.mark.asyncio(loop_scope="session")
async def test_unauthenticated_user_gets_need_login_response(self):
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
This verifies that the requires_auth guard in BaseTool.execute() fires
before _execute() is called, so unauthenticated callers cannot probe
which integrations are configured.
"""
import json
tool = self._make_tool()
# Session still needs a user_id string; the None is passed to execute()
# to simulate an unauthenticated call.
session = make_session(user_id=_TEST_USER_ID)
result = await tool.execute(
user_id=None,
session=session,
tool_call_id="test-call-id",
provider="github",
)
raw = result.output
output = json.loads(raw) if isinstance(raw, str) else raw
assert output.get("type") == "need_login"
assert result.success is False

View File

@@ -8,13 +8,11 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import credit_db, workspace_db
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.exceptions import BlockError
from backend.util.type import coerce_inputs_to_schema
from .models import BlockOutputResponse, ErrorResponse, ToolResponseBase
@@ -117,21 +115,6 @@ async def execute_block(
# Coerce non-matching data types to the expected input schema.
coerce_inputs_to_schema(input_data, block.input_schema)
# Pre-execution credit check (courtesy; spend_credits is atomic)
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
_credit_db = credit_db()
if has_cost:
balance = await _credit_db.get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
f"Insufficient credits to run '{block.name}'. "
"Please top up your credits to continue."
),
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
@@ -140,51 +123,6 @@ async def execute_block(
):
outputs[output_name].append(output_data)
# Charge credits for block execution
if has_cost:
try:
await _credit_db.spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=synthetic_graph_id,
graph_id=synthetic_graph_id,
node_id=synthetic_node_id,
node_exec_id=node_exec_id,
block_id=block_id,
block=block.name,
input=cost_filter,
reason="copilot_block_execution",
),
)
except Exception as e:
# Block already executed (with possible side effects). Never
# return ErrorResponse here — the user received output and
# deserves it. Log the billing failure for reconciliation.
leak_type = (
"INSUFFICIENT_BALANCE"
if isinstance(e, InsufficientBalanceError)
else "UNEXPECTED_ERROR"
)
logger.error(
"BILLING_LEAK[%s]: block executed but credit charge failed — "
"user_id=%s, block_id=%s, node_exec_id=%s, cost=%s: %s",
leak_type,
user_id,
block_id,
node_exec_id,
cost,
e,
extra={
"json_fields": {
"billing_leak": True,
"leak_type": leak_type,
"user_id": user_id,
"cost": str(cost),
}
},
)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
@@ -195,14 +133,14 @@ async def execute_block(
)
except BlockError as e:
logger.warning("Block execution failed: %s", e)
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),

View File

@@ -1,197 +1,18 @@
"""Tests for execute_block — credit charging and type coercion."""
"""Tests for execute_block type coercion in helpers.py.
Verifies that execute_block() coerces string input values to match the block's
expected input types, mirroring the executor's validate_exec() logic.
This is critical for @@agptfile: expansion, where file content is always a string
but the block may expect structured types (e.g. list[list[str]]).
"""
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
_USER = "test-user-helpers"
_SESSION = "test-session-helpers"
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
"""Create a minimal mock block for execute_block()."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = BlockType.STANDARD
mock.input_schema = MagicMock()
mock.input_schema.get_credentials_fields_info.return_value = {}
async def _execute(
input_data: dict, **kwargs: Any
) -> AsyncIterator[tuple[str, Any]]:
yield "result", "ok"
mock.execute = _execute
return mock
def _patch_workspace():
"""Patch workspace_db to return a mock workspace."""
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_ws_db = MagicMock()
mock_ws_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
def _patch_credit_db(
get_credits_return: int = 100,
spend_credits_side_effect: Any = None,
):
"""Patch credit_db accessor to return a mock credit adapter."""
mock_credit = MagicMock()
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
if spend_credits_side_effect is not None:
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
else:
mock_credit.spend_credits = AsyncMock()
return (
patch(
"backend.copilot.tools.helpers.credit_db",
return_value=mock_credit,
),
mock_credit,
)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={"text": "hello"},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_credit.spend_credits.assert_awaited_once()
call_kwargs = mock_credit.spend_credits.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, ErrorResponse)
assert "Insufficient credits" in result.message
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db()
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_credit.get_credits.assert_not_awaited()
mock_credit.spend_credits.assert_not_awaited()
async def test_returns_output_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, output is still returned (block already ran)."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(
get_credits_return=15,
spend_credits_side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
),
)
with (
_patch_workspace(),
patch(
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
credit_patch,
):
result = await execute_block(
block=block,
block_id="block-1",
input_data={},
user_id=_USER,
session_id=_SESSION,
node_exec_id="exec-1",
matched_credentials={},
)
# Block already executed (with side effects), so output is returned
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# ---------------------------------------------------------------------------
# Type coercion tests
# ---------------------------------------------------------------------------
from backend.copilot.tools.models import BlockOutputResponse
def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
@@ -207,7 +28,7 @@ def _make_block_schema(annotations: dict[str, Any]) -> MagicMock:
return schema
def _make_coerce_block(
def _make_block(
block_id: str,
name: str,
annotations: dict[str, Any],
@@ -239,7 +60,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
block = _make_coerce_block(
block = _make_block(
"sheets-write",
"Google Sheets Write",
{"values": list[list[str]], "spreadsheet_id": str},
@@ -282,7 +103,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string → list[str]."""
block = _make_coerce_block(
block = _make_block(
"list-block",
"List Block",
{"items": list[str]},
@@ -314,7 +135,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string → dict[str, str]."""
block = _make_coerce_block(
block = _make_block(
"dict-block",
"Dict Block",
{"config": dict[str, str]},
@@ -346,7 +167,7 @@ async def test_coerce_json_string_to_dict():
@pytest.mark.asyncio(loop_scope="session")
async def test_no_coercion_when_type_matches():
"""Already-correct types pass through without coercion."""
block = _make_coerce_block(
block = _make_block(
"pass-through",
"Pass Through",
{"values": list[list[str]], "name": str},
@@ -380,7 +201,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number → int."""
block = _make_coerce_block(
block = _make_block(
"int-block",
"Int Block",
{"count": int},
@@ -413,7 +234,7 @@ async def test_coerce_string_to_int():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_skips_none_values():
"""None values are not coerced (they may be optional fields)."""
block = _make_coerce_block(
block = _make_block(
"optional-block",
"Optional Block",
{"data": list[str], "label": str},
@@ -446,7 +267,7 @@ async def test_coerce_skips_none_values():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_union_type_preserves_valid_member():
"""Union-typed fields should not be coerced when the value matches a member."""
block = _make_coerce_block(
block = _make_block(
"union-block",
"Union Block",
{"content": str | list[str]},
@@ -480,7 +301,7 @@ async def test_coerce_union_type_preserves_valid_member():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_inner_elements_of_generic():
"""Inner elements of generic containers are recursively coerced."""
block = _make_coerce_block(
block = _make_block(
"inner-coerce",
"Inner Coerce",
{"values": list[str]},

View File

@@ -129,16 +129,3 @@ def review_db():
review_db = get_database_manager_async_client()
return review_db
def credit_db():
if db.is_connected():
from backend.data import db_manager as _credit_db
credit_db = _credit_db
else:
from backend.util.clients import get_database_manager_async_client
credit_db = get_database_manager_async_client()
return credit_db

View File

@@ -148,11 +148,6 @@ async def _get_credits(user_id: str) -> int:
return await user_credit_model.get_credits(user_id)
# Public aliases used by db_accessors.credit_db() when Prisma is connected
get_credits = _get_credits
spend_credits = _spend_credits
class DatabaseManager(AppService):
"""Database connection pooling service.
@@ -517,10 +512,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Credits ============ #
spend_credits = d.spend_credits
get_credits = d.get_credits
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding

View File

@@ -25,6 +25,35 @@ logger = logging.getLogger(__name__)
settings = Settings()
_on_creds_changed: Callable[[str, str], None] | None = None
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
"""Register a callback invoked after any credential is created/updated/deleted.
The callback receives ``(user_id, provider)`` and should be idempotent.
Only one hook can be registered at a time; calling this again replaces the
previous hook. Intended to be called once at application startup by the
copilot module to bust its token cache without creating an import cycle.
"""
global _on_creds_changed
_on_creds_changed = hook
def _bust_copilot_cache(user_id: str, provider: str) -> None:
"""Invoke the registered hook (if any) to bust downstream token caches."""
if _on_creds_changed is not None:
try:
_on_creds_changed(user_id, provider)
except Exception:
logger.warning(
"Credential-change hook failed for user=%s provider=%s",
user_id,
provider,
exc_info=True,
)
class IntegrationCredentialsManager:
"""
Handles the lifecycle of integration credentials.
@@ -69,7 +98,11 @@ class IntegrationCredentialsManager:
return self._locks
async def create(self, user_id: str, credentials: Credentials) -> None:
return await self.store.add_creds(user_id, credentials)
result = await self.store.add_creds(user_id, credentials)
# Bust the copilot token cache so that the next bash_exec picks up the
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
_bust_copilot_cache(user_id, credentials.provider)
return result
async def exists(self, user_id: str, credentials_id: str) -> bool:
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
@@ -156,6 +189,8 @@ class IntegrationCredentialsManager:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Bust copilot cache so the refreshed token is picked up immediately.
_bust_copilot_cache(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
@@ -168,10 +203,17 @@ class IntegrationCredentialsManager:
async def update(self, user_id: str, updated: Credentials) -> None:
async with self._locked(user_id, updated.id):
await self.store.update_creds(user_id, updated)
# Bust the copilot token cache so the updated credential is picked up immediately.
_bust_copilot_cache(user_id, updated.provider)
async def delete(self, user_id: str, credentials_id: str) -> None:
async with self._locked(user_id, credentials_id):
# Read inside the lock to avoid TOCTOU — another coroutine could
# delete the same credential between the read and the delete.
creds = await self.store.get_creds_by_id(user_id, credentials_id)
await self.store.delete_creds_by_id(user_id, credentials_id)
if creds:
_bust_copilot_cache(user_id, creds.provider)
# -- Locking utilities -- #

View File

@@ -1,8 +1,14 @@
"use client";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { SidebarProvider } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { UploadSimple } from "@phosphor-icons/react";
import { DotsThree, UploadSimple } from "@phosphor-icons/react";
import { useCallback, useRef, useState } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
@@ -83,9 +89,10 @@ export function CopilotPage() {
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
// Delete functionality (available via ChatSidebar context menu on all viewports)
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
@@ -141,6 +148,38 @@ export function CopilotPage() {
isUploadingFiles={isUploadingFiles}
droppedFiles={droppedFiles}
onDroppedFilesConsumed={handleDroppedFilesConsumed}
headerSlot={
isMobile && sessionId ? (
<div className="flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className="rounded p-1.5 hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-5 w-5 text-neutral-600" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={() => {
const session = sessions.find(
(s) => s.id === sessionId,
);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
) : undefined
}
/>
</div>
</div>

View File

@@ -2,6 +2,7 @@
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ReactNode } from "react";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
@@ -20,6 +21,7 @@ export interface ChatContainerProps {
onSend: (message: string, files?: File[]) => void | Promise<void>;
onStop: () => void;
isUploadingFiles?: boolean;
headerSlot?: ReactNode;
/** Files dropped onto the chat window. */
droppedFiles?: File[];
/** Called after droppedFiles have been consumed by ChatInput. */
@@ -38,6 +40,7 @@ export const ChatContainer = ({
onSend,
onStop,
isUploadingFiles,
headerSlot,
droppedFiles,
onDroppedFilesConsumed,
}: ChatContainerProps) => {
@@ -60,6 +63,7 @@ export const ChatContainer = ({
status={status}
error={error}
isLoading={isLoadingSession}
headerSlot={headerSlot}
sessionID={sessionId}
/>
<motion.div

View File

@@ -30,6 +30,7 @@ interface Props {
status: string;
error: Error | undefined;
isLoading: boolean;
headerSlot?: React.ReactNode;
sessionID?: string | null;
}
@@ -101,6 +102,7 @@ export function ChatMessagesContainer({
status,
error,
isLoading,
headerSlot,
sessionID,
}: Props) {
const lastMessage = messages[messages.length - 1];
@@ -133,6 +135,7 @@ export function ChatMessagesContainer({
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
{headerSlot}
{isLoading && messages.length === 0 && (
<div
className="flex flex-1 items-center justify-center"

View File

@@ -3,6 +3,7 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
import {
@@ -129,6 +130,8 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
case "tool-search_docs":
case "tool-get_doc_page":
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
case "tool-connect_integration":
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
case "tool-run_block":
case "tool-continue_run_block":
return <RunBlockTool key={key} part={part as ToolUIPart} />;

View File

@@ -37,7 +37,6 @@ import { useCopilotUIStore } from "../../store";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
import { UsageLimits } from "../UsageLimits/UsageLimits";
export function ChatSidebar() {
const { state } = useSidebar();
@@ -257,10 +256,11 @@ export function ChatSidebar() {
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="flex items-center">
<UsageLimits />
<div className="relative left-5 flex items-center gap-1">
<NotificationToggle />
<SidebarTrigger />
<div className="relative left-1">
<SidebarTrigger />
</div>
</div>
</div>
{sessionId ? (

View File

@@ -7,7 +7,6 @@ import {
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { toast } from "@/components/molecules/Toast/use-toast";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
import { Bell, BellRinging, BellSlash } from "@phosphor-icons/react";
import { useCopilotUIStore } from "../../../../store";
@@ -49,7 +48,10 @@ export function NotificationToggle() {
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Notification settings">
<button
className="rounded p-1 text-black transition-colors hover:bg-zinc-50"
aria-label="Notification settings"
>
{!isNotificationsEnabled ? (
<BellSlash className="!size-5" />
) : isSoundEnabled ? (
@@ -57,7 +59,7 @@ export function NotificationToggle() {
) : (
<Bell className="!size-5" />
)}
</Button>
</button>
</PopoverTrigger>
<PopoverContent align="start" className="w-56 p-3">
<div className="flex flex-col gap-3">

View File

@@ -1,38 +0,0 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/molecules/Popover/Popover";
import { Button } from "@/components/ui/button";
import { ChartBar } from "@phosphor-icons/react";
import { UsagePanelContent } from "./UsagePanelContent";
export { UsagePanelContent, formatResetTime } from "./UsagePanelContent";
export function UsageLimits() {
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="icon" aria-label="Usage limits">
<ChartBar className="!size-5" weight="light" />
</Button>
</PopoverTrigger>
<PopoverContent align="start" className="w-64 p-3">
<UsagePanelContent usage={usage} />
</PopoverContent>
</Popover>
);
}

View File

@@ -1,118 +0,0 @@
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import Link from "next/link";
export function formatResetTime(
resetsAt: Date | string,
now: Date = new Date(),
): string {
const resetDate =
typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt;
const diffMs = resetDate.getTime() - now.getTime();
if (diffMs <= 0) return "now";
const hours = Math.floor(diffMs / (1000 * 60 * 60));
// Under 24h: show relative time ("in 4h 23m")
if (hours < 24) {
const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60));
if (hours > 0) return `in ${hours}h ${minutes}m`;
return `in ${minutes}m`;
}
// Over 24h: show day and time in local timezone ("Mon 12:00 AM PST")
return resetDate.toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
}
function UsageBar({
label,
used,
limit,
resetsAt,
}: {
label: string;
used: number;
limit: number;
resetsAt: Date | string;
}) {
if (limit <= 0) return null;
const rawPercent = (used / limit) * 100;
const percent = Math.min(100, Math.round(rawPercent));
const isHigh = percent >= 80;
const percentLabel =
used > 0 && percent === 0 ? "<1% used" : `${percent}% used`;
return (
<div className="flex flex-col gap-1">
<div className="flex items-baseline justify-between">
<span className="text-xs font-medium text-neutral-700">{label}</span>
<span className="text-[11px] tabular-nums text-neutral-500">
{percentLabel}
</span>
</div>
<div className="text-[10px] text-neutral-400">
Resets {formatResetTime(resetsAt)}
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className={`h-full rounded-full transition-[width] duration-300 ease-out ${
isHigh ? "bg-orange-500" : "bg-blue-500"
}`}
style={{ width: `${Math.max(used > 0 ? 1 : 0, percent)}%` }}
/>
</div>
</div>
);
}
export function UsagePanelContent({
usage,
showBillingLink = true,
}: {
usage: CoPilotUsageStatus;
showBillingLink?: boolean;
}) {
const hasDailyLimit = usage.daily.limit > 0;
const hasWeeklyLimit = usage.weekly.limit > 0;
if (!hasDailyLimit && !hasWeeklyLimit) {
return (
<div className="text-xs text-neutral-500">No usage limits configured</div>
);
}
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
{hasDailyLimit && (
<UsageBar
label="Today"
used={usage.daily.used}
limit={usage.daily.limit}
resetsAt={usage.daily.resets_at}
/>
)}
{hasWeeklyLimit && (
<UsageBar
label="This week"
used={usage.weekly.used}
limit={usage.weekly.limit}
resetsAt={usage.weekly.resets_at}
/>
)}
{showBillingLink && (
<Link
href="/profile/credits"
className="text-[11px] text-blue-600 hover:underline"
>
Learn more about usage limits
</Link>
)}
</div>
);
}

View File

@@ -1,124 +0,0 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsageLimits } from "../UsageLimits";
// Mock the generated Orval hook
const mockUseGetV2GetCopilotUsage = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts),
}));
// Mock Popover to render children directly (Radix portals don't work in happy-dom)
vi.mock("@/components/molecules/Popover/Popover", () => ({
Popover: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverTrigger: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PopoverContent: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
afterEach(() => {
cleanup();
mockUseGetV2GetCopilotUsage.mockReset();
});
function makeUsage({
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
};
}
describe("UsageLimits", () => {
it("renders nothing while loading", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: undefined,
isLoading: true,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders nothing when no limits are configured", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }),
isLoading: false,
});
const { container } = render(<UsageLimits />);
expect(container.innerHTML).toBe("");
});
it("renders the usage button when limits exist", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined();
});
it("displays daily and weekly usage percentages", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("50% used")).toBeDefined();
expect(screen.getByText("Today")).toBeDefined();
expect(screen.getByText("This week")).toBeDefined();
expect(screen.getByText("Usage limits")).toBeDefined();
});
it("shows only weekly bar when daily limit is 0", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({
dailyLimit: 0,
weeklyUsed: 25000,
weeklyLimit: 50000,
}),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("This week")).toBeDefined();
expect(screen.queryByText("Today")).toBeNull();
});
it("caps percentage at 100% when over limit", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("100% used")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage(),
isLoading: false,
});
render(<UsageLimits />);
const link = screen.getByText("Learn more about usage limits");
expect(link).toBeDefined();
expect(link.closest("a")?.getAttribute("href")).toBe("/profile/credits");
});
});

View File

@@ -0,0 +1,104 @@
"use client";
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
import type { ToolUIPart } from "ai";
import { useState } from "react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
type Props = {
part: ToolUIPart;
};
function parseJson(raw: unknown): unknown {
if (typeof raw === "string") {
try {
return JSON.parse(raw);
} catch {
return null;
}
}
return raw;
}
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
return parsed as SetupRequirementsResponse;
}
return null;
}
function parseError(raw: unknown): string | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "message" in parsed) {
return String((parsed as { message: unknown }).message);
}
return null;
}
export function ConnectIntegrationTool({ part }: Props) {
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
const [isDismissed, setIsDismissed] = useState(false);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
const output =
part.state === "output-available"
? parseOutput((part as { output?: unknown }).output)
: null;
const errorMessage = isError
? (parseError((part as { output?: unknown }).output) ??
"Failed to connect integration")
: null;
const rawProvider =
(part as { input?: { provider?: string } }).input?.provider ?? "";
const providerName =
output?.setup_info?.agent_name ??
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
// prevent runaway text in the DOM.
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
const label = isStreaming
? `Connecting ${providerName}`
: isError
? `Failed to connect ${providerName}`
: output
? `Connect ${output.setup_info?.agent_name ?? providerName}`
: `Connect ${providerName}`;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<MorphingTextAnimation
text={label}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isError && errorMessage && (
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
)}
{output && (
<div className="mt-2">
{isDismissed ? (
<ContentMessage>Connected. Continuing</ContentMessage>
) : (
<SetupRequirementsCard
output={output}
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
retryInstruction="I've connected my account. Please continue."
onComplete={() => setIsDismissed(true)}
/>
)}
</div>
)}
</div>
);
}

View File

@@ -23,12 +23,16 @@ interface Props {
/** Override the label shown above the credentials section.
* Defaults to "Credentials". */
credentialsLabel?: string;
/** Called after Proceed is clicked so the parent can persist the dismissed state
* across remounts (avoids re-enabling the Proceed button on remount). */
onComplete?: () => void;
}
export function SetupRequirementsCard({
output,
retryInstruction,
credentialsLabel,
onComplete,
}: Props) {
const { onSend } = useCopilotChatActions();
@@ -68,13 +72,17 @@ export function SetupRequirementsCard({
return v !== undefined && v !== null && v !== "";
});
if (hasSent) {
return <ContentMessage>Connected. Continuing</ContentMessage>;
}
const canRun =
!hasSent &&
(!needsCredentials || isAllCredentialsComplete) &&
(!needsInputs || isAllInputsComplete);
function handleRun() {
setHasSent(true);
onComplete?.();
const parts: string[] = [];
if (needsCredentials) {

View File

@@ -1,5 +1,4 @@
import {
getGetV2GetCopilotUsageQueryKey,
getGetV2GetSessionQueryKey,
postV2CancelSessionTask,
} from "@/app/api/__generated__/endpoints/chat/chat";
@@ -178,41 +177,12 @@ export function useCopilotStream({
onError: (error) => {
if (!sessionId) return;
// Detect rate limit (429) responses and show reset time to the user.
// The SDK throws a plain Error whose message is the raw response body
// (FastAPI returns {"detail": "...usage limit..."} for 429s).
let errorDetail: string = error.message;
try {
const parsed = JSON.parse(error.message) as unknown;
if (
typeof parsed === "object" &&
parsed !== null &&
"detail" in parsed &&
typeof (parsed as { detail: unknown }).detail === "string"
) {
errorDetail = (parsed as { detail: string }).detail;
}
} catch {
// Not JSON — use message as-is
}
const isRateLimited = errorDetail.toLowerCase().includes("usage limit");
if (isRateLimited) {
toast({
title: "Usage limit reached",
description:
errorDetail ||
"You've reached your usage limit. Please try again later.",
variant: "destructive",
});
return;
}
// Detect authentication failures (from getAuthHeaders or 401 responses)
const isAuthError =
errorDetail.includes("Authentication failed") ||
errorDetail.includes("Unauthorized") ||
errorDetail.includes("Not authenticated") ||
errorDetail.toLowerCase().includes("401");
error.message.includes("Authentication failed") ||
error.message.includes("Unauthorized") ||
error.message.includes("Not authenticated") ||
error.message.toLowerCase().includes("401");
if (isAuthError) {
toast({
title: "Authentication error",
@@ -337,9 +307,6 @@ export function useCopilotStream({
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
queryClient.invalidateQueries({
queryKey: getGetV2GetCopilotUsageQueryKey(),
});
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;

View File

@@ -11,9 +11,6 @@ import {
import { RefundModal } from "./RefundModal";
import { CreditTransaction } from "@/lib/autogpt-server-api";
import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat";
import {
Table,
@@ -24,32 +21,6 @@ import {
TableRow,
} from "@/components/__legacy__/ui/table";
function CoPilotUsageSection() {
const router = useRouter();
const { data: usage, isLoading } = useGetV2GetCopilotUsage({
query: {
select: (res) => res.data as CoPilotUsageStatus,
refetchInterval: 30000,
staleTime: 10000,
},
});
if (isLoading || !usage) return null;
if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null;
return (
<div className="my-6 space-y-4">
<h3 className="text-lg font-medium">CoPilot Usage Limits</h3>
<div className="rounded-lg border border-neutral-200 p-4">
<UsagePanelContent usage={usage} showBillingLink={false} />
</div>
<Button className="w-full" onClick={() => router.push("/copilot")}>
Open CoPilot
</Button>
</div>
);
}
export default function CreditsPage() {
const api = useBackendAPI();
const {
@@ -266,13 +237,11 @@ export default function CreditsPage() {
</Button>
)}
</form>
{/* CoPilot Usage Limits */}
<CoPilotUsageSection />
</div>
<div className="my-6 space-y-4">
{/* Payment Portal */}
<h3 className="text-lg font-medium">Manage Your Payment Methods</h3>
<p className="text-neutral-600">
You can manage your cards and see your payment history in the

View File

@@ -1267,7 +1267,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to a per-turn Redis stream for reconnection support. If the client\ndisconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1382,28 +1382,6 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/usage": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8477,16 +8455,6 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"CoPilotUsageStatus": {
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" }
},
"type": "object",
"required": ["daily", "weekly"],
"title": "CoPilotUsageStatus",
"description": "Current usage status for a user across all windows."
},
"ContentType": {
"type": "string",
"enum": [
@@ -12222,16 +12190,6 @@
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
},
"total_prompt_tokens": {
"type": "integer",
"title": "Total Prompt Tokens",
"default": 0
},
"total_completion_tokens": {
"type": "integer",
"title": "Total Completion Tokens",
"default": 0
}
},
"type": "object",
@@ -14629,25 +14587,6 @@
"required": ["timezone"],
"title": "UpdateTimezoneRequest"
},
"UsageWindow": {
"properties": {
"used": { "type": "integer", "title": "Used" },
"limit": {
"type": "integer",
"title": "Limit",
"description": "Maximum tokens allowed in this window. 0 means unlimited."
},
"resets_at": {
"type": "string",
"format": "date-time",
"title": "Resets At"
}
},
"type": "object",
"required": ["used", "limit", "resets_at"],
"title": "UsageWindow",
"description": "Usage within a single time window."
},
"UserHistoryResponse": {
"properties": {
"history": {

View File

@@ -125,9 +125,9 @@ export function useCredentialsInput({
if (hasAttemptedAutoSelect.current) return;
hasAttemptedAutoSelect.current = true;
// Auto-select if exactly one credential matches.
// For optional fields with multiple options, let the user choose.
if (isOptional && savedCreds.length > 1) return;
// Auto-select only when there is exactly one saved credential.
// With multiple options the user must choose — regardless of optional/required.
if (savedCreds.length > 1) return;
const cred = savedCreds[0];
onSelectCredential({

View File

@@ -288,7 +288,6 @@ const SidebarTrigger = React.forwardRef<
ref={ref}
data-sidebar="trigger"
variant="ghost"
size="icon"
onClick={(event) => {
onClick?.(event);
toggleSidebar();