Compare commits

..

6 Commits

Author SHA1 Message Date
Nicholas Tindle
8892bcd230 docs: Add workspace and media file architecture documentation (#11989)
### Changes 🏗️

- Added comprehensive architecture documentation at
`docs/platform/workspace-media-architecture.md` covering:
  - Database models (`UserWorkspace`, `UserWorkspaceFile`)
  - `WorkspaceManager` API with session scoping
- `store_media_file()` media normalization pipeline (input types, return
formats)
  - Virus scanning responsibility boundaries
- Decision tree for choosing `WorkspaceManager` vs `store_media_file()`
- Configuration reference including `clamav_max_concurrency` and
`clamav_mark_failed_scans_as_clean`
  - Common patterns with error handling examples
- Updated `autogpt_platform/backend/CLAUDE.md` with a "Workspace & Media
Files" section referencing the new docs
- Removed duplicate `scan_content_safe()` call from
`WriteWorkspaceFileTool` — `WorkspaceManager.write_file()` already scans
internally, so the tool was double-scanning every file
- Replaced removed comment in `workspace.py` with explicit ownership
comment clarifying that `WorkspaceManager` is the single scanning
boundary

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified `scan_content_safe()` is called inside
`WorkspaceManager.write_file()` (workspace.py:186)
- [x] Verified `store_media_file()` scans all input branches including
local paths (file.py:351)
- [x] Verified documentation accuracy against current source code after
merge with dev
  - [x] CI checks all passing

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Mostly adds documentation and internal developer guidance; the only
code change is a comment clarifying `WorkspaceManager.write_file()` as
the single virus-scanning boundary, with no behavior change.
> 
> **Overview**
> Adds a new `docs/platform/workspace-media-architecture.md` describing
the Workspace storage layer vs the `store_media_file()` media pipeline,
including session scoping and virus-scanning/persistence responsibility
boundaries.
> 
> Updates backend `CLAUDE.md` to point contributors to the new doc when
working on CoPilot uploads/downloads or
`WorkspaceManager`/`store_media_file()`, and clarifies in
`WorkspaceManager.write_file()` (comment-only) that callers should not
duplicate virus scanning.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
18fcfa03f8. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 06:12:26 +00:00
Zamil Majdy
48ff8300a4 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-03-17 13:13:42 +07:00
Abhimanyu Yadav
c268fc6464 test(frontend/builder): add integration tests for builder stores, components, and hooks (part-1) (#12433)
### Changes
- Add 329 integration tests across 11 test files for the builder (visual
  workflow editor)
- Cover all Zustand stores (nodeStore, edgeStore, historyStore,
graphStore,
  copyPasteStore, blockMenuStore, controlPanelStore)
- Cover key components (CustomNode, NewBlockMenu, NewSaveControl,
RunGraph)
- Cover hooks (useFlow, useCopyPaste)

### Test files

  | File | Tests | Coverage |
  |------|-------|----------|
| `nodeStore.test.ts` | 58 | Node lifecycle, bulk ops, backend
conversion,
  execution tracking, status, errors, resolution mode |
  | `edgeStore.test.ts` | 37 | Edge CRUD, duplicate rejection, bead
  visualization, backend link conversion, upsert |
| `historyStore.test.ts` | 22 | Undo/redo, history limits (50),
microtask
  batching, deduplication, canUndo/canRedo |
| `graphStore.test.ts` | 28 | Execution status transitions,
isGraphRunning,
  schema management, sub-graphs |
| `copyPasteStore.test.ts` | 8 | Copy/paste with ID remapping, position
offset,
   edge preservation |
| `CustomNode.test.tsx` | 25 | Rendering by block type (NOTE, WEBHOOK,
AGENT,
  OUTPUT, AYRSHARE), error states |
| `NewBlockMenu.test.tsx` | 29 | Store state (search, filters, creators,
  categories), search/default view routing |
| `NewSaveControl.test.tsx` | 11 | Save dialog rendering, form
validation,
  version display, popover state |
| `RunGraph.test.tsx` | 11 | Run/stop button states, loading, click
handlers,
  RunInputDialog visibility |
  | `useFlow.test.ts` | 4 | Loading states, initial load completion |
| `useCopyPaste.test.ts` | 16 | Clipboard copy/paste, UUID remapping,
viewport
  centering, input field guard |
2026-03-17 05:24:55 +00:00
Reinier van der Leer
aff3fb44af ci(platform): Improve end-to-end CI & reduce its cost (#12437)
Our CI costs are skyrocketing, most of it because of
`platform-fullstack-ci.yml`. The `types` job currently uses in a
`big-boi` runner (= expensive), but doesn't need to.
Additionally, the "end-to-end tests" job is currently in
`platform-frontend-ci.yml` instead of `platform-fullstack-ci.yml`,
causing it not to run on backend changes (which it should).

### Changes 🏗️

- Simplify `check-api-types` job (renamed from `types`) and make it use
regular `ubuntu-latest` runner
- Export API schema from backend through CLI (instead of spinning it up
in docker)
- Fix dependency caching in `platform-fullstack-ci.yml` (based on recent
improvements in `platform-frontend-ci.yml`)
- Move `e2e_tests` job to `platform-fullstack-ci.yml`

Out-of-scope but necessary:
- Eliminate module-level init of OpenAI client in
`backend.copilot.service`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI
2026-03-16 23:08:18 +00:00
Zamil Majdy
9a41312769 feat(backend/copilot): parse @@agptfile bare refs by file extension (#12392)
The `@@agptfile:` expansion system previously used content-sniffing
(trying
`json.loads` then `csv.Sniffer`) to decide whether to parse file content
as
structured data. This was fragile — a file containing just `"42"` would
be
parsed as an integer, and the heuristics could misfire on ambiguous
content.

This PR replaces content-sniffing with **extension/MIME-based format
detection**.
When the file has a well-known extension (`.json`, `.csv`, etc.) or MIME
type
fragment (`workspace://id#application/json`), the content is parsed
accordingly.
Unknown formats or parse failures always fall back to plain string — no
surprises.

> [!NOTE]
> This PR builds on the `@@agptfile:` file reference protocol introduced
in #12332 and the structured data auto-parsing added in #12390.
>
> **What is `@@agptfile:`?**
> It is a special URI prefix (e.g. `@@agptfile:workspace:///report.csv`)
that the CoPilot SDK expands inline before sending tool arguments to
blocks. This lets the AI reference workspace files by name, and the SDK
automatically reads and injects the file content. See #12332 for the
full design.

### Changes 🏗️

**New utility: `backend/util/file_content_parser.py`**
- `infer_format(uri)` — determines format from file extension or MIME
fragment
- `parse_file_content(content, fmt)` — parses content, never raises
- Supported text formats: JSON, JSONL/NDJSON, CSV, TSV, YAML, TOML
- Supported binary formats: Parquet (via pyarrow), Excel/XLSX (via
openpyxl)
- JSON scalars (strings, numbers, booleans, null) stay as strings — only
  containers (arrays, objects) are promoted
- CSV/TSV require ≥1 row and ≥2 columns to qualify as tabular data
- Added `openpyxl` dependency for Excel reading via pandas
- Case-insensitive MIME fragment matching per RFC 2045
- Shared `PARSE_EXCEPTIONS` constant to avoid duplication between
modules

**Updated `expand_file_refs_in_args` in `file_ref.py`**
- Bare refs now use `infer_format` + `parse_file_content` instead of the
  old `_try_parse_structured` content-sniffing function
- Binary formats (parquet, xlsx) read raw bytes via `read_file_bytes`
- Embedded refs (text around `@@agptfile:`) still produce plain strings
- **Size guards**: Workspace and sandbox file reads now enforce a 10 MB
limit
  (matching the existing local file limit) to prevent OOM on large files

**Updated `blocks/github/commits.py`**
- Consolidated `_create_blob` and `_create_binary_blob` into a single
function
  with an `encoding` parameter

**Updated copilot system prompt**
- Documents the extension-based structured data parsing and supported
formats

**66 new tests** in `file_content_parser_test.py` covering:
- Format inference (extension, MIME, case-insensitive, precedence)
- All 8 format parsers (happy path + edge cases + fallbacks)
- Binary format handling (string input fallback, invalid bytes fallback)
- Unknown format passthrough

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All 66 file_content_parser_test.py tests pass
  - [x] All 31 file_ref_test.py tests pass
  - [x] All 13 file_ref_integration_test.py tests pass
  - [x] `poetry run format` passes clean (including pyright)
2026-03-16 22:31:21 +00:00
Otto
0b594a219c feat(copilot): support prompt-in-URL for shareable prompt links (#12406)
Requested by @torantula

Add support for shareable AutoPilot URLs that contain a prompt in the
URL hash fragment, inspired by [Lovable's
implementation](https://docs.lovable.dev/integrations/build-with-url).

**URL format:**
- `/copilot#prompt=URL-encoded-text` — pre-fills the input for the user
to review before sending
- `/copilot?autosubmit=true#prompt=...` — auto-creates a session and
sends the prompt immediately

**Example:**
```
https://platform.agpt.co/copilot#prompt=Create%20a%20todo%20app
https://platform.agpt.co/copilot?autosubmit=true#prompt=Create%20a%20todo%20app
```

**Key design decisions:**
- Uses URL fragment (`#`) instead of query params — fragments never hit
the server, so prompts stay client-side only (better for privacy, no
backend URL length limits)
- URL is cleaned via `history.replaceState` immediately after extraction
to prevent re-triggering on navigation/reload
- Leverages existing `pendingMessage` + `createSession()` flow for
auto-submit — no new backend APIs needed
- For populate-only mode, passes `initialPrompt` down through component
tree to pre-fill the chat input

**Files changed:**
- `useCopilotPage.ts` — URL hash extraction logic + `initialPrompt`
state
- `CopilotPage.tsx` — passes `initialPrompt` to `ChatContainer`
- `ChatContainer.tsx` — passes `initialPrompt` to `EmptySession`
- `EmptySession.tsx` — passes `initialPrompt` to `ChatInput`
- `ChatInput.tsx` / `useChatInput.ts` — accepts `initialValue` to
pre-fill the textarea

Fixes SECRT-2119

---
Co-authored-by: Toran Bruce Richards (@Torantulino) <toran@agpt.co>
2026-03-13 23:54:54 +07:00
99 changed files with 10007 additions and 7457 deletions

View File

@@ -5,12 +5,14 @@ on:
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/backend/**"
- "autogpt_platform/autogpt_libs/**"
merge_group:

View File

@@ -120,175 +120,6 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs
integration_test:
runs-on: ubuntu-latest
needs: setup

View File

@@ -1,14 +1,18 @@
name: AutoGPT Platform - Frontend CI
name: AutoGPT Platform - Full-stack CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- ".github/workflows/scripts/docker-ci-fix-compose-build-cache.py"
- ".github/workflows/scripts/get_package_version_from_lockfile.py"
- "autogpt_platform/**"
merge_group:
@@ -24,42 +28,28 @@ defaults:
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
- name: Set up Node
uses: actions/setup-node@v6
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Install dependencies
- name: Install dependencies to populate cache
run: pnpm install --frozen-lockfile
types:
runs-on: big-boi
check-api-types:
name: check API types
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
@@ -67,70 +57,256 @@ jobs:
with:
submodules: recursive
- name: Set up Node.js
# ------------------------ Backend setup ------------------------
- name: Set up Backend - Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Set up Backend - Install Poetry
working-directory: autogpt_platform/backend
run: |
POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Installing Poetry version ${POETRY_VERSION}"
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
- name: Set up Backend - Set up dependency cache
uses: actions/cache@v5
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Set up Backend - Install dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Set up Backend - Generate Prisma client
working-directory: autogpt_platform/backend
run: poetry run prisma generate && poetry run gen-prisma-stub
- name: Set up Frontend - Export OpenAPI schema from Backend
working-directory: autogpt_platform/backend
run: poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
# ------------------------ Frontend setup ------------------------
- name: Set up Frontend - Enable corepack
run: corepack enable
- name: Set up Frontend - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
- name: Restore dependencies cache
uses: actions/cache@v5
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
- name: Set up Frontend - Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Set up Frontend - Format OpenAPI schema
id: format-schema
run: pnpm prettier --write ./src/app/api/openapi.json
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "The openapi.json file has been modified after exporting the API schema."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo "\nIn the backend directory:"
echo "1. Run 'poetry run export-api-schema --output ../frontend/src/app/api/openapi.json'"
echo "\nIn the frontend directory:"
echo "2. Run 'pnpm prettier --write src/app/api/openapi.json'"
echo "3. Run 'pnpm generate:api'"
echo "4. Run 'pnpm types'"
echo "5. Fix any TypeScript errors that may have been introduced"
echo "6. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
- name: Set up Frontend - Generate API client
id: generate-api-client
run: pnpm orval --config ./orval.config.ts
# Continue with type generation & check even if there are schema changes
if: success() || (steps.format-schema.outcome == 'success')
- name: Check for TypeScript errors
run: pnpm types
if: success() || (steps.generate-api-client.outcome == 'success')
e2e_test:
name: end-to-end tests
runs-on: big-boi
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
submodules: recursive
- name: Set up Platform - Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Set up Platform - Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Platform - Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver: docker-container
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform
run: |
pip install pyyaml
# Resolve extends and generate a flat compose file that bake can understand
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
# Add cache configuration to the resolved compose file
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
--source docker-compose.resolved.yml \
--cache-from "type=gha" \
--cache-to "type=gha,mode=max" \
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend/**') }}" \
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src/**') }}" \
--git-ref "${{ github.ref }}"
# Build with bake using the resolved compose file (now includes cache config)
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Cache E2E test data
id: e2e-data-cache
uses: actions/cache@v5
with:
path: /tmp/e2e_test_data.sql
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-fullstack-ci.yml') }}
- name: Set up Platform - Start Supabase DB + Auth
run: |
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
echo "Waiting for auth service to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
- name: Set up Platform - Run migrations
run: |
echo "Running migrations..."
docker compose -f ../docker-compose.resolved.yml run --rm migrate
echo "✅ Migrations completed"
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Load cached E2E test data
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
run: |
echo "✅ Found cached E2E test data, restoring..."
{
echo "SET session_replication_role = 'replica';"
cat /tmp/e2e_test_data.sql
echo "SET session_replication_role = 'origin';"
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
# Refresh materialized views after restore
docker compose -f ../docker-compose.resolved.yml exec -T db \
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
echo "✅ E2E test data restored from cache"
- name: Set up Platform - Start (all other services)
run: |
docker compose -f ../docker-compose.resolved.yml up -d --no-build
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
env:
NEXT_PUBLIC_PW_TEST: true
- name: Set up tests - Create E2E test data
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
run: |
echo "Creating E2E test data..."
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
exit 1
}
# Dump auth.users + platform schema for cache (two separate dumps)
echo "Dumping database for cache..."
{
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--table='auth.users' postgres
docker compose -f ../docker-compose.resolved.yml exec -T db \
pg_dump -U postgres --data-only --column-inserts \
--schema=platform \
--exclude-table='platform._prisma_migrations' \
--exclude-table='platform.apscheduler_jobs' \
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
postgres
} > /tmp/e2e_test_data.sql
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
- name: Set up tests - Enable corepack
run: corepack enable
- name: Set up tests - Set up Node
uses: actions/setup-node@v6
with:
node-version: "22.18.0"
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Install dependencies
run: pnpm install --frozen-lockfile
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.resolved.yml logs

View File

@@ -178,6 +178,16 @@ yield "image_url", result_url
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Workspace & Media Files
**Read [Workspace & Media Architecture](../../docs/platform/workspace-media-architecture.md) when:**
- Working on CoPilot file upload/download features
- Building blocks that handle `MediaFileType` inputs/outputs
- Modifying `WorkspaceManager` or `store_media_file()`
- Debugging file persistence or virus scanning issues
Covers: `WorkspaceManager` (persistent storage with session scoping), `store_media_file()` (media normalization pipeline), and responsibility boundaries for virus scanning and persistence.
## Security Implementation
### Cache Protection Middleware

View File

@@ -6,13 +6,11 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
import prisma.enums
from pydantic import BaseModel, EmailStr
from backend.copilot.session_types import ChatSessionStartType
from backend.data.model import UserTransaction
from backend.util.models import Pagination
if TYPE_CHECKING:
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
from backend.data.model import User
class UserHistoryResponse(BaseModel):
@@ -92,51 +90,3 @@ class BulkInvitedUsersResponse(BaseModel):
for row in result.results
],
)
class AdminCopilotUserSummary(BaseModel):
id: str
email: str
name: Optional[str] = None
timezone: str
created_at: datetime
updated_at: datetime
@classmethod
def from_user(cls, user: "User") -> "AdminCopilotUserSummary":
return cls(
id=user.id,
email=user.email,
name=user.name,
timezone=user.timezone,
created_at=user.created_at,
updated_at=user.updated_at,
)
class AdminCopilotUsersResponse(BaseModel):
users: list[AdminCopilotUserSummary]
class TriggerCopilotSessionRequest(BaseModel):
user_id: str
start_type: ChatSessionStartType
class TriggerCopilotSessionResponse(BaseModel):
session_id: str
start_type: ChatSessionStartType
class SendCopilotEmailsRequest(BaseModel):
user_id: str
class SendCopilotEmailsResponse(BaseModel):
candidate_count: int
processed_count: int
sent_count: int
skipped_count: int
repair_queued_count: int
running_count: int
failed_count: int

View File

@@ -2,12 +2,8 @@ import logging
import math
from autogpt_libs.auth import get_user_id, requires_admin_user
from fastapi import APIRouter, File, HTTPException, Query, Security, UploadFile
from fastapi import APIRouter, File, Query, Security, UploadFile
from backend.copilot.autopilot import (
send_pending_copilot_emails_for_user,
trigger_autopilot_session_for_user,
)
from backend.data.invited_user import (
bulk_create_invited_users_from_file,
create_invited_user,
@@ -16,20 +12,13 @@ from backend.data.invited_user import (
revoke_invited_user,
)
from backend.data.tally import mask_email
from backend.data.user import search_users
from backend.util.models import Pagination
from .model import (
AdminCopilotUsersResponse,
AdminCopilotUserSummary,
BulkInvitedUsersResponse,
CreateInvitedUserRequest,
InvitedUserResponse,
InvitedUsersResponse,
SendCopilotEmailsRequest,
SendCopilotEmailsResponse,
TriggerCopilotSessionRequest,
TriggerCopilotSessionResponse,
)
logger = logging.getLogger(__name__)
@@ -146,95 +135,3 @@ async def retry_invited_user_tally_route(
invited_user_id,
)
return InvitedUserResponse.from_record(invited_user)
@router.get(
"/copilot/users",
response_model=AdminCopilotUsersResponse,
summary="Search Copilot Users",
operation_id="getV2SearchCopilotUsers",
)
async def search_copilot_users_route(
search: str = Query("", description="Search by email, name, or user ID"),
limit: int = Query(20, ge=1, le=50),
admin_user_id: str = Security(get_user_id),
) -> AdminCopilotUsersResponse:
logger.info(
"Admin user %s searched Copilot users (query_length=%s, limit=%s)",
admin_user_id,
len(search.strip()),
limit,
)
users = await search_users(search, limit=limit)
return AdminCopilotUsersResponse(
users=[AdminCopilotUserSummary.from_user(user) for user in users]
)
@router.post(
"/copilot/trigger",
response_model=TriggerCopilotSessionResponse,
summary="Trigger Copilot Session",
operation_id="postV2TriggerCopilotSession",
)
async def trigger_copilot_session_route(
request: TriggerCopilotSessionRequest,
admin_user_id: str = Security(get_user_id),
) -> TriggerCopilotSessionResponse:
logger.info(
"Admin user %s manually triggered %s for user %s",
admin_user_id,
request.start_type,
request.user_id,
)
try:
session = await trigger_autopilot_session_for_user(
request.user_id,
start_type=request.start_type,
)
except LookupError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
logger.info(
"Admin user %s created manual Copilot session %s for user %s",
admin_user_id,
session.session_id,
request.user_id,
)
return TriggerCopilotSessionResponse(
session_id=session.session_id,
start_type=request.start_type,
)
@router.post(
"/copilot/send-emails",
response_model=SendCopilotEmailsResponse,
summary="Send Pending Copilot Emails",
operation_id="postV2SendPendingCopilotEmails",
)
async def send_pending_copilot_emails_route(
request: SendCopilotEmailsRequest,
admin_user_id: str = Security(get_user_id),
) -> SendCopilotEmailsResponse:
logger.info(
"Admin user %s manually triggered pending Copilot emails for user %s",
admin_user_id,
request.user_id,
)
result = await send_pending_copilot_emails_for_user(request.user_id)
logger.info(
"Admin user %s completed pending Copilot email sweep for user %s "
"(candidates=%s, sent=%s, skipped=%s, repairs=%s, running=%s, failed=%s)",
admin_user_id,
request.user_id,
result.candidate_count,
result.sent_count,
result.skipped_count,
result.repair_queued_count,
result.running_count,
result.failed_count,
)
return SendCopilotEmailsResponse(**result.model_dump())

View File

@@ -8,15 +8,11 @@ import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.copilot.autopilot_email import PendingCopilotEmailSweepResult
from backend.copilot.model import ChatSession
from backend.copilot.session_types import ChatSessionStartType
from backend.data.invited_user import (
BulkInvitedUserRowResult,
BulkInvitedUsersResult,
InvitedUserRecord,
)
from backend.data.model import User
from .user_admin_routes import router as user_admin_router
@@ -76,20 +72,6 @@ def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
)
def _sample_user() -> User:
now = datetime.now(timezone.utc)
return User(
id="user-1",
email="copilot@example.com",
name="Copilot User",
timezone="Europe/Madrid",
created_at=now,
updated_at=now,
stripe_customer_id=None,
top_up_config=None,
)
def test_get_invited_users(
mocker: pytest_mock.MockerFixture,
) -> None:
@@ -184,107 +166,3 @@ def test_retry_invited_user_tally(
assert response.status_code == 200
assert response.json()["tally_status"] == "RUNNING"
def test_search_copilot_users(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.search_users",
AsyncMock(return_value=[_sample_user()]),
)
response = client.get("/admin/copilot/users", params={"search": "copilot"})
assert response.status_code == 200
data = response.json()
assert len(data["users"]) == 1
assert data["users"][0]["email"] == "copilot@example.com"
assert data["users"][0]["timezone"] == "Europe/Madrid"
def test_trigger_copilot_session(
mocker: pytest_mock.MockerFixture,
) -> None:
session = ChatSession.new(
"user-1",
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
)
trigger = mocker.patch(
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
AsyncMock(return_value=session),
)
response = client.post(
"/admin/copilot/trigger",
json={
"user_id": "user-1",
"start_type": ChatSessionStartType.AUTOPILOT_CALLBACK.value,
},
)
assert response.status_code == 200
assert response.json()["session_id"] == session.session_id
assert response.json()["start_type"] == "AUTOPILOT_CALLBACK"
assert trigger.await_args is not None
assert trigger.await_args.args[0] == "user-1"
assert (
trigger.await_args.kwargs["start_type"]
== ChatSessionStartType.AUTOPILOT_CALLBACK
)
def test_trigger_copilot_session_returns_not_found(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.admin.user_admin_routes.trigger_autopilot_session_for_user",
AsyncMock(side_effect=LookupError("User not found with ID: missing-user")),
)
response = client.post(
"/admin/copilot/trigger",
json={
"user_id": "missing-user",
"start_type": ChatSessionStartType.AUTOPILOT_NIGHTLY.value,
},
)
assert response.status_code == 404
assert response.json()["detail"] == "User not found with ID: missing-user"
def test_send_pending_copilot_emails(
mocker: pytest_mock.MockerFixture,
) -> None:
send_emails = mocker.patch(
"backend.api.features.admin.user_admin_routes.send_pending_copilot_emails_for_user",
AsyncMock(
return_value=PendingCopilotEmailSweepResult(
candidate_count=1,
processed_count=1,
sent_count=1,
skipped_count=0,
repair_queued_count=0,
running_count=0,
failed_count=0,
)
),
)
response = client.post(
"/admin/copilot/send-emails",
json={"user_id": "user-1"},
)
assert response.status_code == 200
assert response.json() == {
"candidate_count": 1,
"processed_count": 1,
"sent_count": 1,
"skipped_count": 0,
"repair_queued_count": 0,
"running_count": 0,
"failed_count": 0,
}
send_emails.assert_awaited_once_with("user-1")

View File

@@ -3,23 +3,18 @@
import asyncio
import logging
import re
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Any, NoReturn
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.autopilot import (
consume_callback_token,
strip_internal_content,
unwrap_internal_content,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
@@ -32,13 +27,7 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamHeartbeat,
)
from backend.copilot.session_types import ChatSessionStartType
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.e2b_sandbox import kill_sandbox
from backend.copilot.tools.models import (
AgentDetailsResponse,
@@ -64,7 +53,6 @@ from backend.copilot.tools.models import (
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.db_accessors import workspace_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import get_business_understanding
from backend.data.workspace import get_or_create_workspace
@@ -77,187 +65,6 @@ _UUID_RE = re.compile(
)
logger = logging.getLogger(__name__)
STREAM_QUEUE_GET_TIMEOUT_SECONDS = 10.0
STREAMING_RESPONSE_HEADERS = {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
}
def _build_streaming_response(
generator: AsyncGenerator[str, None],
) -> StreamingResponse:
return StreamingResponse(
generator,
media_type="text/event-stream",
headers=STREAMING_RESPONSE_HEADERS,
)
async def _unsubscribe_stream_queue(
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse] | None,
) -> None:
if subscriber_queue is None:
return
try:
await stream_registry.unsubscribe_from_session(session_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
exc_info=True,
)
async def _stream_subscriber_queue(
*,
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
log_meta: dict[str, Any],
started_at: float,
label: str,
surface_errors: bool,
) -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(
subscriber_queue.get(),
timeout=STREAM_QUEUE_GET_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
continue
chunk_count += 1
if first_chunk_type is None:
first_chunk_type = type(chunk).__name__
elapsed = (time.perf_counter() - started_at) * 1000
logger.info(
f"[TIMING] {label} first chunk at {elapsed:.1f}ms, type={first_chunk_type}",
extra={
"json_fields": {
**log_meta,
"chunk_type": first_chunk_type,
"elapsed_ms": elapsed,
}
},
)
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - started_at) * 1000
logger.info(
f"[TIMING] {label} received StreamFinish in {total_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunk_count,
"total_time_ms": total_time,
}
},
)
break
except GeneratorExit:
logger.info(
f"[TIMING] {label} client disconnected after {chunk_count} chunks",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunk_count,
"reason": "client_disconnect",
}
},
)
except Exception as exc:
elapsed = (time.perf_counter() - started_at) * 1000
logger.error(
f"[TIMING] {label} error after {elapsed:.1f}ms: {exc}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(exc)}
},
)
if surface_errors:
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
total_time = (time.perf_counter() - started_at) * 1000
logger.info(
f"[TIMING] {label} finished in {total_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"chunks_yielded": chunk_count,
"first_chunk_type": first_chunk_type,
}
},
)
yield "data: [DONE]\n\n"
async def _stream_chat_events(
*,
session_id: str,
user_id: str | None,
subscribe_from_id: str,
turn_id: str,
log_meta: dict[str, Any],
) -> AsyncGenerator[str, None]:
started_at = time.perf_counter()
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id=subscribe_from_id,
)
try:
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
async for chunk in _stream_subscriber_queue(
session_id=session_id,
subscriber_queue=subscriber_queue,
log_meta=log_meta,
started_at=started_at,
label=f"stream_chat_post[{turn_id}]",
surface_errors=True,
):
yield chunk
finally:
await _unsubscribe_stream_queue(session_id, subscriber_queue)
async def _resume_stream_events(
*,
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> AsyncGenerator[str, None]:
started_at = time.perf_counter()
try:
async for chunk in _stream_subscriber_queue(
session_id=session_id,
subscriber_queue=subscriber_queue,
log_meta={"session_id": session_id},
started_at=started_at,
label=f"resume_stream[{session_id}]",
surface_errors=False,
):
yield chunk
finally:
await _unsubscribe_stream_queue(session_id, subscriber_queue)
async def _validate_and_get_session(
@@ -311,8 +118,6 @@ class SessionDetailResponse(BaseModel):
created_at: str
updated_at: str
user_id: str | None
start_type: ChatSessionStartType
execution_tag: str | None = None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
@@ -324,8 +129,6 @@ class SessionSummaryResponse(BaseModel):
created_at: str
updated_at: str
title: str | None = None
start_type: ChatSessionStartType
execution_tag: str | None = None
is_processing: bool
@@ -357,14 +160,6 @@ class UpdateSessionTitleRequest(BaseModel):
return stripped
class ConsumeCallbackTokenRequest(BaseModel):
token: str
class ConsumeCallbackTokenResponse(BaseModel):
session_id: str
# ========== Routes ==========
@@ -376,7 +171,6 @@ async def list_sessions(
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=100),
offset: int = Query(default=0, ge=0),
with_auto: bool = Query(default=False),
) -> ListSessionsResponse:
"""
List chat sessions for the authenticated user.
@@ -392,12 +186,7 @@ async def list_sessions(
Returns:
ListSessionsResponse: List of session summaries and total count.
"""
sessions, total_count = await get_user_sessions(
user_id,
limit,
offset,
with_auto=with_auto,
)
sessions, total_count = await get_user_sessions(user_id, limit, offset)
# Batch-check Redis for active stream status on each session
processing_set: set[str] = set()
@@ -428,8 +217,6 @@ async def list_sessions(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
start_type=session.start_type,
execution_tag=session.execution_tag,
is_processing=session.session_id in processing_set,
)
for session in sessions
@@ -581,26 +368,12 @@ async def get_session(
if not session:
raise NotFoundError(f"Session {session_id} not found.")
messages = []
for message in session.messages:
payload = message.model_dump()
if message.role == "user":
visible_content = strip_internal_content(message.content)
if (
visible_content is None
and session.start_type != ChatSessionStartType.MANUAL
):
visible_content = unwrap_internal_content(message.content)
if visible_content is None:
continue
payload["content"] = visible_content
messages.append(payload)
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
session_id,
user_id,
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
@@ -621,28 +394,11 @@ async def get_session(
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
start_type=session.start_type,
execution_tag=session.execution_tag,
messages=messages,
active_stream=active_stream_info,
)
@router.post(
"/sessions/callback-token/consume",
dependencies=[Security(auth.requires_user)],
)
async def consume_callback_token_route(
request: ConsumeCallbackTokenRequest,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> ConsumeCallbackTokenResponse:
try:
result = await consume_callback_token(request.token, user_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return ConsumeCallbackTokenResponse(session_id=result.session_id)
@router.post(
"/sessions/{session_id}/cancel",
status_code=200,
@@ -716,6 +472,9 @@ async def stream_chat_post(
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
import time
stream_start_time = time.perf_counter()
log_meta = {"component": "ChatStream", "session_id": session_id}
if user_id:
@@ -747,14 +506,18 @@ async def stream_chat_post(
if valid_ids:
workspace = await get_or_create_workspace(user_id)
files = await workspace_db().get_workspace_files_by_ids(
workspace_id=workspace.id,
file_ids=valid_ids,
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mime_type}, {round(wf.size_bytes / 1024, 1)} KB), file_id={wf.id}"
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
@@ -824,14 +587,141 @@ async def stream_chat_post(
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
return _build_streaming_response(
_stream_chat_events(
session_id=session_id,
user_id=user_id,
subscribe_from_id=subscribe_from_id,
turn_id=turn_id,
log_meta=log_meta,
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
logger.info(
"[TIMING] Starting to read from subscriber_queue",
extra={"json_fields": log_meta},
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
chunks_yielded += 1
if not first_chunk_yielded:
first_chunk_yielded = True
elapsed = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
f"type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"elapsed_ms": elapsed * 1000,
}
},
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
f"n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"total_time_ms": total_time * 1000,
}
},
)
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
logger.info(
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"chunks_yielded": chunks_yielded,
"reason": "client_disconnect",
}
},
)
pass # Client disconnected - background task continues
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
extra={
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time * 1000,
"chunks_yielded": chunks_yielded,
}
},
)
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -857,7 +747,11 @@ async def resume_session_stream(
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
"""
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
if not active_session:
return Response(status_code=204)
@@ -874,11 +768,64 @@ async def resume_session_stream(
if subscriber_queue is None:
return Response(status_code=204)
return _build_streaming_response(
_resume_stream_events(
session_id=session_id,
subscriber_queue=subscriber_queue,
)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
exc_info=True,
)
logger.info(
"Resume stream completed",
extra={
"session_id": session_id,
"n_chunks": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@@ -1030,6 +977,6 @@ ToolResponseUnion = (
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> NoReturn:
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -1,8 +1,5 @@
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
import asyncio
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import fastapi
@@ -11,9 +8,6 @@ import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamFinish
from backend.copilot.session_types import ChatSessionStartType
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -121,238 +115,6 @@ def test_update_title_not_found(
assert response.status_code == 404
def test_list_sessions_defaults_to_manual_only(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
started_at = datetime.now(timezone.utc)
mock_get_user_sessions = mocker.patch(
"backend.api.features.chat.routes.get_user_sessions",
new_callable=AsyncMock,
return_value=(
[
SimpleNamespace(
session_id="sess-1",
started_at=started_at,
updated_at=started_at,
title="Nightly check-in",
start_type=chat_routes.ChatSessionStartType.AUTOPILOT_NIGHTLY,
execution_tag="autopilot-nightly:2026-03-13",
)
],
1,
),
)
pipe = MagicMock()
pipe.hget = MagicMock()
pipe.execute = AsyncMock(return_value=["running"])
redis = MagicMock()
redis.pipeline = MagicMock(return_value=pipe)
mocker.patch(
"backend.api.features.chat.routes.get_redis_async",
new_callable=AsyncMock,
return_value=redis,
)
response = client.get("/sessions")
assert response.status_code == 200
assert response.json() == {
"sessions": [
{
"id": "sess-1",
"created_at": started_at.isoformat(),
"updated_at": started_at.isoformat(),
"title": "Nightly check-in",
"start_type": "AUTOPILOT_NIGHTLY",
"execution_tag": "autopilot-nightly:2026-03-13",
"is_processing": True,
}
],
"total": 1,
}
mock_get_user_sessions.assert_awaited_once_with(
test_user_id,
50,
0,
with_auto=False,
)
def test_list_sessions_can_include_auto_sessions(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_get_user_sessions = mocker.patch(
"backend.api.features.chat.routes.get_user_sessions",
new_callable=AsyncMock,
return_value=([], 0),
)
response = client.get("/sessions?with_auto=true")
assert response.status_code == 200
assert response.json() == {"sessions": [], "total": 0}
mock_get_user_sessions.assert_awaited_once_with(
test_user_id,
50,
0,
with_auto=True,
)
def test_consume_callback_token_route_returns_session_id(
mocker: pytest_mock.MockerFixture,
) -> None:
mock_consume = mocker.patch(
"backend.api.features.chat.routes.consume_callback_token",
new_callable=AsyncMock,
return_value=SimpleNamespace(session_id="sess-2"),
)
response = client.post(
"/sessions/callback-token/consume",
json={"token": "token-123"},
)
assert response.status_code == 200
assert response.json() == {"session_id": "sess-2"}
mock_consume.assert_awaited_once_with("token-123", TEST_USER_ID)
def test_consume_callback_token_route_returns_404_on_invalid_token(
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"backend.api.features.chat.routes.consume_callback_token",
new_callable=AsyncMock,
side_effect=ValueError("Callback token not found"),
)
response = client.post(
"/sessions/callback-token/consume",
json={"token": "token-123"},
)
assert response.status_code == 404
assert response.json() == {"detail": "Callback token not found"}
def test_get_session_hides_internal_only_messages_for_manual_sessions(
mocker: pytest_mock.MockerFixture,
) -> None:
session = ChatSession.new(
TEST_USER_ID,
start_type=ChatSessionStartType.MANUAL,
)
session.messages = [
ChatMessage(role="user", content="<internal>hidden</internal>"),
ChatMessage(
role="user",
content="Visible<internal>hidden</internal> text",
),
ChatMessage(role="assistant", content="Public response"),
]
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=session,
)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(None, None),
)
response = client.get(f"/sessions/{session.session_id}")
assert response.status_code == 200
assert response.json()["messages"] == [
{
"role": "user",
"content": "Visible text",
"name": None,
"tool_call_id": None,
"refusal": None,
"tool_calls": None,
"function_call": None,
},
{
"role": "assistant",
"content": "Public response",
"name": None,
"tool_call_id": None,
"refusal": None,
"tool_calls": None,
"function_call": None,
},
]
def test_get_session_shows_cleaned_internal_kickoff_for_autopilot_sessions(
mocker: pytest_mock.MockerFixture,
) -> None:
session = ChatSession.new(
TEST_USER_ID,
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
execution_tag="autopilot-nightly:2026-03-13",
)
session.messages = [
ChatMessage(role="user", content="<internal>hidden</internal>"),
ChatMessage(
role="user",
content="Visible<internal>hidden</internal> text",
),
ChatMessage(role="assistant", content="Public response"),
]
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=session,
)
mocker.patch(
"backend.api.features.chat.routes.stream_registry.get_active_session",
new_callable=AsyncMock,
return_value=(None, None),
)
response = client.get(f"/sessions/{session.session_id}")
assert response.status_code == 200
assert response.json()["messages"] == [
{
"role": "user",
"content": "hidden",
"name": None,
"tool_call_id": None,
"refusal": None,
"tool_calls": None,
"function_call": None,
},
{
"role": "user",
"content": "Visible text",
"name": None,
"tool_call_id": None,
"refusal": None,
"tool_calls": None,
"function_call": None,
},
{
"role": "assistant",
"content": "Public response",
"name": None,
"tool_call_id": None,
"refusal": None,
"tool_calls": None,
"function_call": None,
},
]
# ─── file_ids Pydantic validation ─────────────────────────────────────
@@ -380,11 +142,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
return_value=None,
)
mock_registry = mocker.MagicMock()
subscriber_queue = asyncio.Queue()
subscriber_queue.put_nowait(StreamFinish())
mock_registry.create_session = mocker.AsyncMock(return_value=None)
mock_registry.subscribe_to_session = mocker.AsyncMock(return_value=subscriber_queue)
mock_registry.unsubscribe_from_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
@@ -407,11 +165,11 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
workspace_store = mocker.MagicMock()
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"backend.api.features.chat.routes.workspace_db",
return_value=workspace_store,
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
@@ -437,11 +195,11 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
return_value=type("W", (), {"id": "ws-1"})(),
)
workspace_store = mocker.MagicMock()
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"backend.api.features.chat.routes.workspace_db",
return_value=workspace_store,
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
@@ -459,10 +217,9 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
)
# The find_many call should only receive the one valid UUID
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
workspace_id="ws-1",
file_ids=[valid_id],
)
mock_prisma.find_many.assert_called_once()
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ─── Cross-workspace file_ids ─────────────────────────────────────────
@@ -476,11 +233,11 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
workspace_store = mocker.MagicMock()
workspace_store.get_workspace_files_by_ids = mocker.AsyncMock(return_value=[])
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"backend.api.features.chat.routes.workspace_db",
return_value=workspace_store,
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
@@ -489,10 +246,9 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
json={"message": "hi", "file_ids": [fid]},
)
workspace_store.get_workspace_files_by_ids.assert_called_once_with(
workspace_id="my-workspace-id",
file_ids=[fid],
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False
# ─── Suggested prompts endpoint ──────────────────────────────────────

View File

@@ -11,7 +11,10 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import parse_data_uri, resolve_media_content
from backend.util.type import MediaFileType
from ._api import get_api
from ._auth import (
@@ -178,7 +181,8 @@ class FileOperation(StrEnum):
class FileOperationInput(TypedDict):
path: str
content: str
# MediaFileType is a str NewType — no runtime breakage for existing callers.
content: MediaFileType
operation: FileOperation
@@ -275,11 +279,11 @@ class GithubMultiFileCommitBlock(Block):
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str) -> str:
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": "utf-8"},
json={"content": content, "encoding": encoding},
)
return blob_response.json()["sha"]
@@ -301,10 +305,19 @@ class GithubMultiFileCommitBlock(Block):
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently
# Create all blobs concurrently. Data URIs (from store_media_file)
# are sent as base64 blobs to preserve binary content.
if upsert_files:
async def _make_blob(content: str) -> str:
parsed = parse_data_uri(content)
if parsed is not None:
_, b64_payload = parsed
return await _create_blob(b64_payload, encoding="base64")
return await _create_blob(content)
blob_shas = await asyncio.gather(
*[_create_blob(content) for _, content in upsert_files]
*[_make_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
@@ -358,15 +371,36 @@ class GithubMultiFileCommitBlock(Block):
input_data: Input,
*,
credentials: GithubCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
# Resolve media references (workspace://, data:, URLs) to data
# URIs so _make_blob can send binary content correctly.
resolved_files: list[FileOperationInput] = []
for file_op in input_data.files:
content = file_op.get("content", "")
operation = FileOperation(file_op.get("operation", "upsert"))
if operation != FileOperation.DELETE:
content = await resolve_media_content(
MediaFileType(content),
execution_context,
return_format="for_external_api",
)
resolved_files.append(
FileOperationInput(
path=file_op["path"],
content=MediaFileType(content),
operation=operation,
)
)
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
input_data.files,
resolved_files,
)
yield "sha", sha
yield "url", url

View File

@@ -8,6 +8,7 @@ from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.data.execution import ExecutionContext
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
@@ -97,7 +98,11 @@ async def test_multi_file_commit_error_path():
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
async for _ in block.execute(
input_data,
credentials=TEST_CREDENTIALS,
execution_context=ExecutionContext(),
):
pass

View File

@@ -1,135 +0,0 @@
"""Autopilot public API — thin facade re-exporting from sub-modules.
Implementation is split by responsibility:
- autopilot_prompts: constants, prompt templates, context builders
- autopilot_dispatch: timezone helpers, session creation, dispatch/scheduling
- autopilot_completion: completion report extraction, repair, handler
- autopilot_email: email sending, link building, notification sweep
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from pydantic import BaseModel
from backend.copilot.autopilot_completion import ( # noqa: F401
CompletionReportToolCall,
CompletionReportToolCallFunction,
ToolOutputEnvelope,
_build_completion_report_repair_message,
_extract_completion_report_from_session,
_get_pending_approval_metadata,
_queue_completion_report_repair,
handle_non_manual_session_completion,
)
from backend.copilot.autopilot_dispatch import ( # noqa: F401
_bucket_end_for_now,
_create_autopilot_session,
_crosses_local_midnight,
_enqueue_session_turn,
_resolve_timezone_name,
_session_exists_for_execution_tag,
_try_create_callback_session,
_try_create_invite_cta_session,
_try_create_nightly_session,
_user_has_recent_manual_message,
_user_has_session_since,
dispatch_nightly_copilot,
get_callback_execution_tag,
get_graph_exec_id_for_session,
get_invite_cta_execution_tag,
get_nightly_execution_tag,
trigger_autopilot_session_for_user,
)
from backend.copilot.autopilot_email import ( # noqa: F401
PendingCopilotEmailSweepResult,
_build_session_link,
_get_completion_email_template_name,
_markdown_to_email_html,
_send_completion_email,
_send_nightly_copilot_emails,
send_nightly_copilot_emails,
send_pending_copilot_emails_for_user,
)
from backend.copilot.autopilot_prompts import ( # noqa: F401
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
AUTOPILOT_CALLBACK_TAG,
AUTOPILOT_DISABLED_TOOLS,
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
AUTOPILOT_INVITE_CTA_TAG,
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
AUTOPILOT_NIGHTLY_TAG_PREFIX,
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT,
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT,
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT,
INTERNAL_TAG_RE,
MAX_COMPLETION_REPORT_REPAIRS,
_build_autopilot_system_prompt,
_format_start_type_label,
_get_recent_manual_session_context,
_get_recent_sent_email_context,
_get_recent_session_summary_context,
strip_internal_content,
unwrap_internal_content,
wrap_internal_message,
)
from backend.copilot.model import ChatMessage, create_chat_session
from backend.data.db_accessors import chat_db
# -- re-exports from sub-modules (preserves existing import paths) ---------- #
logger = logging.getLogger(__name__)
class CallbackTokenConsumeResult(BaseModel):
session_id: str
async def consume_callback_token(
token_id: str,
user_id: str,
) -> CallbackTokenConsumeResult:
"""Consume a callback token and return the resulting session.
Uses an atomic consume-and-update to prevent the TOCTOU race where two
concurrent requests could each see the token as unconsumed and create
duplicate sessions.
"""
db = chat_db()
token = await db.get_chat_session_callback_token(token_id)
if token is None or token.user_id != user_id:
raise ValueError("Callback token not found")
if token.expires_at <= datetime.now(UTC):
raise ValueError("Callback token has expired")
if token.consumed_session_id:
return CallbackTokenConsumeResult(session_id=token.consumed_session_id)
session = await create_chat_session(
user_id,
initial_messages=[
ChatMessage(role="assistant", content=token.callback_session_message)
],
)
# Atomically mark consumed — if another request already consumed the token
# concurrently, the DB will have a non-null consumed_session_id; we re-read
# and return the winner's session instead.
await db.mark_chat_session_callback_token_consumed(
token_id,
session.session_id,
)
# Re-read to see if we won the race
refreshed = await db.get_chat_session_callback_token(token_id)
if (
refreshed
and refreshed.consumed_session_id
and refreshed.consumed_session_id != session.session_id
):
return CallbackTokenConsumeResult(session_id=refreshed.consumed_session_id)
return CallbackTokenConsumeResult(session_id=session.session_id)

View File

@@ -1,198 +0,0 @@
from __future__ import annotations
import logging
from datetime import UTC, datetime
from pydantic import BaseModel, Field, ValidationError
from backend.copilot.autopilot_dispatch import (
_enqueue_session_turn,
get_graph_exec_id_for_session,
)
from backend.copilot.autopilot_prompts import (
MAX_COMPLETION_REPORT_REPAIRS,
wrap_internal_message,
)
from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
upsert_chat_session,
)
from backend.copilot.session_types import CompletionReportInput, StoredCompletionReport
from backend.data.db_accessors import review_db
logger = logging.getLogger(__name__)
# --------------- models --------------- #
class CompletionReportToolCallFunction(BaseModel):
name: str | None = None
arguments: str | None = None
class CompletionReportToolCall(BaseModel):
id: str
function: CompletionReportToolCallFunction = Field(
default_factory=CompletionReportToolCallFunction
)
class ToolOutputEnvelope(BaseModel):
type: str | None = None
# --------------- approval metadata --------------- #
async def _get_pending_approval_metadata(
session: ChatSession,
) -> tuple[int, str | None]:
graph_exec_id = get_graph_exec_id_for_session(session.session_id)
pending_count = await review_db().count_pending_reviews_for_graph_exec(
graph_exec_id,
session.user_id,
)
return pending_count, graph_exec_id if pending_count > 0 else None
# --------------- extraction --------------- #
def _extract_completion_report_from_session(
session: ChatSession,
*,
pending_approval_count: int,
) -> CompletionReportInput | None:
tool_outputs = {
message.tool_call_id: message.content
for message in session.messages
if message.role == "tool" and message.tool_call_id
}
latest_report: CompletionReportInput | None = None
for message in session.messages:
if message.role != "assistant" or not message.tool_calls:
continue
for tool_call in message.tool_calls:
try:
parsed_tool_call = CompletionReportToolCall.model_validate(tool_call)
except ValidationError:
continue
if parsed_tool_call.function.name != "completion_report":
continue
output = tool_outputs.get(parsed_tool_call.id)
if not output:
continue
try:
output_payload = ToolOutputEnvelope.model_validate_json(output)
except ValidationError:
output_payload = None
if output_payload is not None and output_payload.type == "error":
continue
try:
raw_arguments = parsed_tool_call.function.arguments or "{}"
report = CompletionReportInput.model_validate_json(raw_arguments)
except ValidationError:
continue
if pending_approval_count > 0 and not report.approval_summary:
continue
latest_report = report
return latest_report
# --------------- repair --------------- #
def _build_completion_report_repair_message(
*,
attempt: int,
pending_approval_count: int,
) -> str:
approval_instruction = ""
if pending_approval_count > 0:
approval_instruction = (
f" There are currently {pending_approval_count} pending approval item(s). "
"If they still exist, include approval_summary."
)
return wrap_internal_message(
"The session completed without a valid completion_report tool call. "
f"This is repair attempt {attempt}. Call completion_report now and do not do any additional user-facing work."
+ approval_instruction
)
async def _queue_completion_report_repair(
session: ChatSession,
*,
pending_approval_count: int,
) -> None:
attempt = session.completion_report_repair_count + 1
repair_message = _build_completion_report_repair_message(
attempt=attempt,
pending_approval_count=pending_approval_count,
)
session.messages.append(ChatMessage(role="user", content=repair_message))
session.completion_report_repair_count = attempt
session.completion_report_repair_queued_at = datetime.now(UTC)
session.completed_at = None
session.completion_report = None
await upsert_chat_session(session)
await _enqueue_session_turn(
session,
message=repair_message,
tool_name="completion_report_repair",
)
# --------------- handler --------------- #
async def handle_non_manual_session_completion(session_id: str) -> None:
session = await get_chat_session(session_id)
if session is None or session.is_manual:
return
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
session
)
report = _extract_completion_report_from_session(
session,
pending_approval_count=pending_approval_count,
)
if report is not None:
session.completion_report = StoredCompletionReport(
**report.model_dump(),
has_pending_approvals=pending_approval_count > 0,
pending_approval_count=pending_approval_count,
pending_approval_graph_exec_id=graph_exec_id,
saved_at=datetime.now(UTC),
)
session.completion_report_repair_queued_at = None
session.completed_at = datetime.now(UTC)
await upsert_chat_session(session)
return
if session.completion_report_repair_count >= MAX_COMPLETION_REPORT_REPAIRS:
session.completion_report_repair_queued_at = None
session.completed_at = datetime.now(UTC)
await upsert_chat_session(session)
return
await _queue_completion_report_repair(
session,
pending_approval_count=pending_approval_count,
)

View File

@@ -1,386 +0,0 @@
from __future__ import annotations
import logging
from datetime import UTC, date, datetime, time, timedelta
from typing import TYPE_CHECKING
from uuid import uuid4
from zoneinfo import ZoneInfo
import prisma.enums
from backend.copilot import stream_registry
from backend.copilot.autopilot_prompts import (
AUTOPILOT_CALLBACK_TAG,
AUTOPILOT_DISABLED_TOOLS,
AUTOPILOT_INVITE_CTA_TAG,
AUTOPILOT_NIGHTLY_TAG_PREFIX,
_build_autopilot_system_prompt,
_render_initial_message,
)
from backend.copilot.constants import COPILOT_SESSION_PREFIX
from backend.copilot.executor.utils import enqueue_copilot_turn
from backend.copilot.model import ChatMessage, ChatSession, create_chat_session
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
from backend.data.db_accessors import chat_db, invited_user_db, user_db
from backend.data.model import User
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.settings import Settings
from backend.util.timezone_utils import get_user_timezone_or_utc
if TYPE_CHECKING:
from backend.data.invited_user import InvitedUserRecord
logger = logging.getLogger(__name__)
settings = Settings()
DISPATCH_BATCH_SIZE = 500
# --------------- tag helpers --------------- #
def get_graph_exec_id_for_session(session_id: str) -> str:
return f"{COPILOT_SESSION_PREFIX}{session_id}"
def get_nightly_execution_tag(target_local_date: date) -> str:
return f"{AUTOPILOT_NIGHTLY_TAG_PREFIX}{target_local_date.isoformat()}"
def get_callback_execution_tag() -> str:
return AUTOPILOT_CALLBACK_TAG
def get_invite_cta_execution_tag() -> str:
return AUTOPILOT_INVITE_CTA_TAG
def _get_manual_trigger_execution_tag(start_type: ChatSessionStartType) -> str:
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%fZ")
return f"admin-autopilot:{start_type.value}:{timestamp}:{uuid4()}"
# --------------- timezone helpers --------------- #
def _bucket_end_for_now(now_utc: datetime) -> datetime:
minute = 30 if now_utc.minute >= 30 else 0
return now_utc.replace(minute=minute, second=0, microsecond=0)
def _resolve_timezone_name(raw_timezone: str | None) -> str:
return get_user_timezone_or_utc(raw_timezone)
def _crosses_local_midnight(
bucket_start_utc: datetime,
bucket_end_utc: datetime,
timezone_name: str,
) -> date | None:
"""Return the new local date if *bucket_end_utc* falls on a different local
date than *bucket_start_utc*, taking DST transitions into account.
During a DST spring-forward the wall clock jumps forward (e.g. 01:59 →
03:00). We use ``fold=0`` on the end instant so that ambiguous/missing
times are resolved consistently and a single 30-min UTC bucket can never
produce midnight on *two* consecutive calls.
"""
tz = ZoneInfo(timezone_name)
start_local = bucket_start_utc.astimezone(tz)
end_local = bucket_end_utc.astimezone(tz)
# Resolve ambiguous wall-clock times consistently (spring-forward / fall-back)
end_local = end_local.replace(fold=0)
if start_local.date() == end_local.date():
return None
return end_local.date()
# --------------- thin DB wrappers --------------- #
async def _user_has_recent_manual_message(user_id: str, since: datetime) -> bool:
return await chat_db().has_recent_manual_message(user_id, since)
async def _user_has_session_since(user_id: str, since: datetime) -> bool:
return await chat_db().has_session_since(user_id, since)
async def _session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
return await chat_db().session_exists_for_execution_tag(user_id, execution_tag)
# --------------- session creation --------------- #
async def _enqueue_session_turn(
session: ChatSession,
*,
message: str,
tool_name: str,
) -> None:
turn_id = str(uuid4())
await stream_registry.create_session(
session_id=session.session_id,
user_id=session.user_id,
tool_call_id=tool_name,
tool_name=tool_name,
turn_id=turn_id,
blocking=False,
)
await enqueue_copilot_turn(
session_id=session.session_id,
user_id=session.user_id,
message=message,
turn_id=turn_id,
is_user_message=True,
)
async def _create_autopilot_session(
user: User,
*,
start_type: ChatSessionStartType,
execution_tag: str,
timezone_name: str,
target_local_date: date | None = None,
invited_user: InvitedUserRecord | None = None,
) -> ChatSession | None:
if await _session_exists_for_execution_tag(user.id, execution_tag):
return None
system_prompt = await _build_autopilot_system_prompt(
user,
start_type=start_type,
timezone_name=timezone_name,
target_local_date=target_local_date,
invited_user=invited_user,
)
initial_message = _render_initial_message(
start_type,
user_name=user.name,
invited_user=invited_user,
)
session_config = ChatSessionConfig(
system_prompt_override=system_prompt,
initial_user_message=initial_message,
extra_tools=["completion_report"],
disabled_tools=AUTOPILOT_DISABLED_TOOLS,
)
session = await create_chat_session(
user.id,
start_type=start_type,
execution_tag=execution_tag,
session_config=session_config,
initial_messages=[ChatMessage(role="user", content=initial_message)],
)
await _enqueue_session_turn(
session,
message=initial_message,
tool_name="autopilot_dispatch",
)
return session
# --------------- cohort helpers --------------- #
async def _try_create_invite_cta_session(
user: User,
*,
invited_user: InvitedUserRecord | None,
now_utc: datetime,
timezone_name: str,
invite_cta_start: date,
invite_cta_delay: timedelta,
) -> bool:
if invited_user is None:
return False
if invited_user.status != prisma.enums.InvitedUserStatus.INVITED:
return False
if invited_user.created_at.date() < invite_cta_start:
return False
if invited_user.created_at > now_utc - invite_cta_delay:
return False
if await _session_exists_for_execution_tag(user.id, get_invite_cta_execution_tag()):
return False
created = await _create_autopilot_session(
user,
start_type=ChatSessionStartType.AUTOPILOT_INVITE_CTA,
execution_tag=get_invite_cta_execution_tag(),
timezone_name=timezone_name,
invited_user=invited_user,
)
return created is not None
async def _try_create_nightly_session(
user: User,
*,
now_utc: datetime,
timezone_name: str,
target_local_date: date,
) -> bool:
if not await _user_has_recent_manual_message(
user.id,
now_utc - timedelta(hours=24),
):
return False
created = await _create_autopilot_session(
user,
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
execution_tag=get_nightly_execution_tag(target_local_date),
timezone_name=timezone_name,
target_local_date=target_local_date,
)
return created is not None
async def _try_create_callback_session(
user: User,
*,
callback_start: datetime,
timezone_name: str,
) -> bool:
if not await _user_has_session_since(user.id, callback_start):
return False
if await _session_exists_for_execution_tag(user.id, get_callback_execution_tag()):
return False
created = await _create_autopilot_session(
user,
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
execution_tag=get_callback_execution_tag(),
timezone_name=timezone_name,
)
return created is not None
# --------------- dispatch --------------- #
async def _dispatch_nightly_copilot() -> int:
now_utc = datetime.now(UTC)
bucket_end = _bucket_end_for_now(now_utc)
bucket_start = bucket_end - timedelta(minutes=30)
callback_start = datetime.combine(
settings.config.nightly_copilot_callback_start_date,
time.min,
tzinfo=UTC,
)
invite_cta_start = settings.config.nightly_copilot_invite_cta_start_date
invite_cta_delay = timedelta(
hours=settings.config.nightly_copilot_invite_cta_delay_hours
)
# Paginate user list to avoid loading the entire table into memory.
created_count = 0
cursor: str | None = None
while True:
batch = await user_db().list_users(
limit=DISPATCH_BATCH_SIZE,
cursor=cursor,
)
if not batch:
break
user_ids = [user.id for user in batch]
invites = await invited_user_db().list_invited_users_for_auth_users(user_ids)
invites_by_user_id = {
invite.auth_user_id: invite for invite in invites if invite.auth_user_id
}
for user in batch:
if not await is_feature_enabled(
Flag.NIGHTLY_COPILOT, user.id, default=False
):
continue
timezone_name = _resolve_timezone_name(user.timezone)
target_local_date = _crosses_local_midnight(
bucket_start,
bucket_end,
timezone_name,
)
if target_local_date is None:
continue
invited_user = invites_by_user_id.get(user.id)
if await _try_create_invite_cta_session(
user,
invited_user=invited_user,
now_utc=now_utc,
timezone_name=timezone_name,
invite_cta_start=invite_cta_start,
invite_cta_delay=invite_cta_delay,
):
created_count += 1
continue
if await _try_create_nightly_session(
user,
now_utc=now_utc,
timezone_name=timezone_name,
target_local_date=target_local_date,
):
created_count += 1
continue
if await _try_create_callback_session(
user,
callback_start=callback_start,
timezone_name=timezone_name,
):
created_count += 1
cursor = batch[-1].id if len(batch) == DISPATCH_BATCH_SIZE else None
if cursor is None:
break
return created_count
async def dispatch_nightly_copilot() -> int:
return await _dispatch_nightly_copilot()
async def trigger_autopilot_session_for_user(
user_id: str,
*,
start_type: ChatSessionStartType,
) -> ChatSession:
allowed_start_types = {
ChatSessionStartType.AUTOPILOT_INVITE_CTA,
ChatSessionStartType.AUTOPILOT_NIGHTLY,
ChatSessionStartType.AUTOPILOT_CALLBACK,
}
if start_type not in allowed_start_types:
raise ValueError(f"Unsupported autopilot start type: {start_type}")
try:
user = await user_db().get_user_by_id(user_id)
except ValueError as exc:
raise LookupError(str(exc)) from exc
invites = await invited_user_db().list_invited_users_for_auth_users([user_id])
invited_user = invites[0] if invites else None
timezone_name = _resolve_timezone_name(user.timezone)
target_local_date = None
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
target_local_date = datetime.now(UTC).astimezone(ZoneInfo(timezone_name)).date()
session = await _create_autopilot_session(
user,
start_type=start_type,
execution_tag=_get_manual_trigger_execution_tag(start_type),
timezone_name=timezone_name,
target_local_date=target_local_date,
invited_user=invited_user,
)
if session is None:
raise ValueError("Failed to create autopilot session")
return session

View File

@@ -1,297 +0,0 @@
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from markdown_it import MarkdownIt
from pydantic import BaseModel
from backend.copilot import stream_registry
from backend.copilot.autopilot_completion import (
_get_pending_approval_metadata,
_queue_completion_report_repair,
)
from backend.copilot.autopilot_prompts import (
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE,
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE,
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE,
MAX_COMPLETION_REPORT_REPAIRS,
)
from backend.copilot.model import (
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.service import _generate_session_title
from backend.copilot.session_types import ChatSessionStartType
from backend.data.db_accessors import chat_db, user_db
from backend.notifications.email import EmailSender
from backend.util.url import get_frontend_base_url
logger = logging.getLogger(__name__)
PENDING_NOTIFICATION_SWEEP_LIMIT = 200
_md = MarkdownIt()
_EMAIL_INLINE_STYLES: list[tuple[str, str]] = [
(
"<p>",
'<p style="font-size: 15px; line-height: 170%;'
" margin-top: 0; margin-bottom: 16px;"
' color: #1F1F20;">',
),
(
"<li>",
'<li style="font-size: 15px; line-height: 170%;'
" margin-top: 0; margin-bottom: 8px;"
' color: #1F1F20;">',
),
(
"<ul>",
'<ul style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
),
(
"<ol>",
'<ol style="padding: 0 0 0 24px;' ' margin-top: 0; margin-bottom: 16px;">',
),
(
"<a ",
'<a style="color: #7733F5;'
" text-decoration: underline;"
' font-weight: 500;" ',
),
(
"<h2>",
'<h2 style="font-size: 20px; font-weight: 600;'
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
),
(
"<h3>",
'<h3 style="font-size: 18px; font-weight: 600;'
' margin-top: 0; margin-bottom: 12px; color: #1F1F20;">',
),
]
def _markdown_to_email_html(text: str | None) -> str:
"""Convert markdown text to email-safe HTML with inline styles."""
if not text or not text.strip():
return ""
html = _md.render(text.strip())
for tag, styled_tag in _EMAIL_INLINE_STYLES:
html = html.replace(tag, styled_tag)
return html.strip()
# --------------- link builders --------------- #
def _build_session_link(session_id: str, *, show_autopilot: bool) -> str:
base_url = get_frontend_base_url()
suffix = "&showAutopilot=1" if show_autopilot else ""
return f"{base_url}/copilot?sessionId={session_id}{suffix}"
def _get_completion_email_template_name(start_type: ChatSessionStartType) -> str:
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
return AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
return AUTOPILOT_CALLBACK_EMAIL_TEMPLATE
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
return AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE
raise ValueError(f"Unsupported start type for completion email: {start_type}")
class PendingCopilotEmailSweepResult(BaseModel):
candidate_count: int = 0
processed_count: int = 0
sent_count: int = 0
skipped_count: int = 0
repair_queued_count: int = 0
running_count: int = 0
failed_count: int = 0
async def _ensure_session_title_for_completed_session(session: ChatSession) -> None:
if session.title or not session.user_id:
return
report = session.completion_report
if report is None:
return
title = report.email_title.strip() if report.email_title else ""
if not title:
title_seed = report.email_body or report.thoughts
if title_seed:
generated_title = await _generate_session_title(
title_seed,
user_id=session.user_id,
session_id=session.session_id,
)
title = generated_title.strip() if generated_title else ""
if not title:
return
updated = await update_session_title(
session.session_id,
session.user_id,
title,
only_if_empty=True,
)
if updated:
session.title = title
# --------------- send email --------------- #
async def _send_completion_email(session: ChatSession) -> None:
report = session.completion_report
if report is None:
raise ValueError("Missing completion report")
try:
user = await user_db().get_user_by_id(session.user_id)
except ValueError as exc:
raise ValueError(f"User {session.user_id} not found") from exc
if not user.email:
raise ValueError(f"User {session.user_id} not found")
approval_cta = report.has_pending_approvals
template_name = _get_completion_email_template_name(session.start_type)
if approval_cta:
cta_url = _build_session_link(session.session_id, show_autopilot=True)
cta_label = "Review in Copilot"
else:
cta_url = _build_session_link(session.session_id, show_autopilot=True)
cta_label = (
"Try Copilot"
if session.start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA
else "Open Copilot"
)
# EmailSender.send_template is synchronous (blocking HTTP call to Postmark).
# Run it in a thread to avoid blocking the async event loop.
sender = EmailSender()
await asyncio.to_thread(
sender.send_template,
user_email=user.email,
subject=report.email_title or "Autopilot update",
template_name=template_name,
data={
"email_body_html": _markdown_to_email_html(report.email_body),
"approval_summary_html": _markdown_to_email_html(report.approval_summary),
"cta_url": cta_url,
"cta_label": cta_label,
},
)
# --------------- email sweep --------------- #
async def _process_pending_copilot_email_candidates(
candidates: list,
) -> PendingCopilotEmailSweepResult:
result = PendingCopilotEmailSweepResult(candidate_count=len(candidates))
for candidate in candidates:
session = await get_chat_session(candidate.session_id)
if session is None or session.is_manual:
continue
active = await stream_registry.get_session(session.session_id)
is_running = active is not None and active.status == "running"
if is_running:
result.running_count += 1
continue
pending_approval_count, graph_exec_id = await _get_pending_approval_metadata(
session
)
if session.completion_report is None:
if session.completion_report_repair_count < MAX_COMPLETION_REPORT_REPAIRS:
await _queue_completion_report_repair(
session,
pending_approval_count=pending_approval_count,
)
result.repair_queued_count += 1
continue
session.completed_at = session.completed_at or datetime.now(UTC)
session.completion_report_repair_queued_at = None
session.notification_email_skipped_at = datetime.now(UTC)
await upsert_chat_session(session)
result.skipped_count += 1
continue
session.completed_at = session.completed_at or datetime.now(UTC)
if (
session.completion_report.pending_approval_graph_exec_id is None
and graph_exec_id
):
session.completion_report = session.completion_report.model_copy(
update={
"has_pending_approvals": pending_approval_count > 0,
"pending_approval_count": pending_approval_count,
"pending_approval_graph_exec_id": graph_exec_id,
}
)
try:
await _ensure_session_title_for_completed_session(session)
except Exception:
logger.exception(
"Failed to ensure session title for session %s",
session.session_id,
)
if not session.completion_report.should_notify_user:
session.notification_email_skipped_at = datetime.now(UTC)
await upsert_chat_session(session)
result.skipped_count += 1
continue
try:
await _send_completion_email(session)
except Exception:
logger.exception(
"Failed to send nightly copilot email for session %s",
session.session_id,
)
result.failed_count += 1
continue
session.notification_email_sent_at = datetime.now(UTC)
await upsert_chat_session(session)
result.sent_count += 1
result.processed_count = result.sent_count + result.skipped_count
return result
async def _send_nightly_copilot_emails() -> int:
candidates = await chat_db().get_pending_notification_chat_sessions(
limit=PENDING_NOTIFICATION_SWEEP_LIMIT
)
result = await _process_pending_copilot_email_candidates(candidates)
return result.processed_count
async def send_nightly_copilot_emails() -> int:
return await _send_nightly_copilot_emails()
async def send_pending_copilot_emails_for_user(
user_id: str,
) -> PendingCopilotEmailSweepResult:
candidates = await chat_db().get_pending_notification_chat_sessions_for_user(
user_id,
limit=PENDING_NOTIFICATION_SWEEP_LIMIT,
)
return await _process_pending_copilot_email_candidates(candidates)

View File

@@ -1,409 +0,0 @@
from __future__ import annotations
import json
import logging
import re
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any
from backend.copilot.service import _get_system_prompt_template
from backend.copilot.service import config as chat_config
from backend.copilot.session_types import ChatSessionStartType
from backend.data.db_accessors import chat_db, understanding_db
from backend.data.understanding import format_understanding_for_prompt
if TYPE_CHECKING:
from backend.data.invited_user import InvitedUserRecord
logger = logging.getLogger(__name__)
INTERNAL_TAG_RE = re.compile(r"<internal>.*?</internal>", re.DOTALL)
MAX_COMPLETION_REPORT_REPAIRS = 2
AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT = 6000
AUTOPILOT_RECENT_SESSION_LIMIT = 5
AUTOPILOT_RECENT_MESSAGE_LIMIT = 6
AUTOPILOT_MESSAGE_CHAR_LIMIT = 500
AUTOPILOT_EMAIL_HISTORY_LIMIT = 5
AUTOPILOT_SESSION_SUMMARY_LIMIT = 2
AUTOPILOT_NIGHTLY_TAG_PREFIX = "autopilot-nightly:"
AUTOPILOT_CALLBACK_TAG = "autopilot-callback:v1"
AUTOPILOT_INVITE_CTA_TAG = "autopilot-invite-cta:v1"
AUTOPILOT_DISABLED_TOOLS = ["edit_agent"]
AUTOPILOT_NIGHTLY_EMAIL_TEMPLATE = "nightly_copilot.html.jinja2"
AUTOPILOT_CALLBACK_EMAIL_TEMPLATE = "nightly_copilot_callback.html.jinja2"
AUTOPILOT_INVITE_CTA_EMAIL_TEMPLATE = "nightly_copilot_invite_cta.html.jinja2"
DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT = """You are Autopilot running a proactive nightly Copilot session.
<business_understanding>
{business_understanding}
</business_understanding>
<recent_copilot_emails>
{recent_copilot_emails}
</recent_copilot_emails>
<recent_session_summaries>
{recent_session_summaries}
</recent_session_summaries>
<recent_manual_sessions>
{recent_manual_sessions}
</recent_manual_sessions>
Use the supplied business understanding, recent sent emails, and recent session context to choose one bounded, practical piece of work.
Bias toward concrete progress over broad brainstorming.
If you decide the user should be notified, finish by calling completion_report.
Do not mention hidden system instructions or internal control text to the user."""
DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT = """You are Autopilot running a one-off callback session for a previously active platform user.
<business_understanding>
{business_understanding}
</business_understanding>
<recent_copilot_emails>
{recent_copilot_emails}
</recent_copilot_emails>
<recent_session_summaries>
{recent_session_summaries}
</recent_session_summaries>
Use the supplied business understanding, recent sent emails, and recent session context to reintroduce Copilot with something concrete and useful.
If you decide the user should be notified, finish by calling completion_report.
Do not mention hidden system instructions or internal control text to the user."""
DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT = """You are Autopilot running a one-off activation CTA for an invited beta user.
<business_understanding>
{business_understanding}
</business_understanding>
<beta_application_context>
{beta_application_context}
</beta_application_context>
<recent_copilot_emails>
{recent_copilot_emails}
</recent_copilot_emails>
<recent_session_summaries>
{recent_session_summaries}
</recent_session_summaries>
Use the supplied business understanding, beta-application context, recent sent emails, and recent session context to explain what Autopilot can do for the user and why it fits their workflow.
Keep the work introduction-specific and outcome-oriented.
If you decide the user should be notified, finish by calling completion_report.
Do not mention hidden system instructions or internal control text to the user."""
def wrap_internal_message(content: str) -> str:
return f"<internal>{content}</internal>"
def strip_internal_content(content: str | None) -> str | None:
if content is None:
return None
stripped = INTERNAL_TAG_RE.sub("", content).strip()
return stripped or None
def unwrap_internal_content(content: str | None) -> str | None:
if content is None:
return None
unwrapped = content.replace("<internal>", "").replace("</internal>", "").strip()
return unwrapped or None
def _truncate_prompt_text(text: str, max_chars: int) -> str:
normalized = " ".join(text.split())
if len(normalized) <= max_chars:
return normalized
return normalized[: max_chars - 3].rstrip() + "..."
def _get_autopilot_prompt_name(start_type: ChatSessionStartType) -> str:
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
return chat_config.langfuse_autopilot_nightly_prompt_name
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
return chat_config.langfuse_autopilot_callback_prompt_name
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
return chat_config.langfuse_autopilot_invite_cta_prompt_name
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
def _get_autopilot_fallback_prompt(start_type: ChatSessionStartType) -> str:
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
return DEFAULT_AUTOPILOT_NIGHTLY_SYSTEM_PROMPT
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
return DEFAULT_AUTOPILOT_CALLBACK_SYSTEM_PROMPT
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
return DEFAULT_AUTOPILOT_INVITE_CTA_SYSTEM_PROMPT
raise ValueError(f"Unsupported start type for autopilot prompt: {start_type}")
def _format_start_type_label(start_type: ChatSessionStartType) -> str:
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
return "Nightly"
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
return "Callback"
if start_type == ChatSessionStartType.AUTOPILOT_INVITE_CTA:
return "Beta Invite CTA"
return start_type.value
def _get_invited_user_tally_understanding(
invited_user: InvitedUserRecord | None,
) -> dict[str, Any] | None:
return invited_user.tally_understanding if invited_user is not None else None
def _render_initial_message(
start_type: ChatSessionStartType,
*,
user_name: str | None,
invited_user: InvitedUserRecord | None = None,
) -> str:
display_name = user_name or "the user"
if start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY:
return wrap_internal_message(
"This is a nightly proactive Copilot session. Review recent manual activity, "
f"do one useful piece of work for {display_name}, and finish with completion_report."
)
if start_type == ChatSessionStartType.AUTOPILOT_CALLBACK:
return wrap_internal_message(
"This is a one-off callback session for a previously active user. "
f"Reintroduce Copilot with something concrete and useful for {display_name}, "
"then finish with completion_report."
)
invite_summary = ""
tally_understanding = _get_invited_user_tally_understanding(invited_user)
if tally_understanding is not None:
invite_summary = "\nKnown context from the beta application:\n" + json.dumps(
tally_understanding, ensure_ascii=False
)
return wrap_internal_message(
"This is a one-off invite CTA session for an invited beta user who has not yet activated. "
f"Create a tailored introduction for {display_name}, explain how Autopilot can help, "
f"and finish with completion_report.{invite_summary}"
)
def _get_previous_local_midnight_utc(
target_local_date: date,
timezone_name: str,
) -> datetime:
from datetime import UTC
from zoneinfo import ZoneInfo
tz = ZoneInfo(timezone_name)
previous_midnight_local = datetime.combine(
target_local_date - timedelta(days=1),
time.min,
tzinfo=tz,
)
return previous_midnight_local.astimezone(UTC)
async def _get_recent_manual_session_context(
user_id: str,
*,
since_utc: datetime,
) -> str:
sessions = await chat_db().get_manual_chat_sessions_since(
user_id,
since_utc,
AUTOPILOT_RECENT_SESSION_LIMIT,
)
if not sessions:
return "No recent manual sessions since the previous nightly run."
blocks: list[str] = []
used_chars = 0
for session in sessions:
messages = await chat_db().get_chat_messages_since(
session.session_id, since_utc
)
visible_messages: list[str] = []
for message in messages[-AUTOPILOT_RECENT_MESSAGE_LIMIT:]:
content = message.content or ""
if message.role == "user":
visible = strip_internal_content(content)
else:
visible = content.strip() or None
if not visible:
continue
role_label = {
"user": "User",
"assistant": "Assistant",
"tool": "Tool",
}.get(message.role, message.role.title())
visible_messages.append(
f"{role_label}: {_truncate_prompt_text(visible, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
)
if not visible_messages:
continue
title_suffix = f" ({session.title})" if session.title else ""
block = (
f"### Session updated {session.updated_at.isoformat()}{title_suffix}\n"
+ "\n".join(visible_messages)
)
if used_chars + len(block) > AUTOPILOT_RECENT_CONTEXT_CHAR_LIMIT:
break
blocks.append(block)
used_chars += len(block)
return (
"\n\n".join(blocks)
if blocks
else "No recent manual sessions since the previous nightly run."
)
async def _get_recent_sent_email_context(user_id: str) -> str:
sessions = await chat_db().get_recent_sent_email_chat_sessions(
user_id,
AUTOPILOT_EMAIL_HISTORY_LIMIT,
)
if not sessions:
return "No recent Copilot or Autopilot emails have been sent to this user."
blocks: list[str] = []
for session in sessions:
report = session.completion_report
sent_at = session.notification_email_sent_at
if report is None or sent_at is None:
continue
lines = [
f"### Sent {sent_at.isoformat()} ({_format_start_type_label(session.start_type)})",
]
if report.email_title:
lines.append(
f"Subject: {_truncate_prompt_text(report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
)
if report.email_body:
lines.append(
f"Body: {_truncate_prompt_text(report.email_body, AUTOPILOT_MESSAGE_CHAR_LIMIT)}"
)
if report.callback_session_message:
lines.append(
"CTA Message: "
+ _truncate_prompt_text(
report.callback_session_message,
AUTOPILOT_MESSAGE_CHAR_LIMIT,
)
)
blocks.append("\n".join(lines))
return (
"\n\n".join(blocks)
if blocks
else "No recent Copilot or Autopilot emails have been sent to this user."
)
async def _get_recent_session_summary_context(user_id: str) -> str:
sessions = await chat_db().get_recent_completion_report_chat_sessions(
user_id,
AUTOPILOT_SESSION_SUMMARY_LIMIT,
)
if not sessions:
return "No recent Copilot session summaries are available."
blocks: list[str] = []
for session in sessions:
report = session.completion_report
if report is None:
continue
title_suffix = f" ({session.title})" if session.title else ""
lines = [
f"### {_format_start_type_label(session.start_type)} session updated {session.updated_at.isoformat()}{title_suffix}",
f"Summary: {_truncate_prompt_text(report.thoughts, AUTOPILOT_MESSAGE_CHAR_LIMIT)}",
]
if report.email_title:
lines.append(
"Email Title: "
+ _truncate_prompt_text(
report.email_title, AUTOPILOT_MESSAGE_CHAR_LIMIT
)
)
blocks.append("\n".join(lines))
return (
"\n\n".join(blocks)
if blocks
else "No recent Copilot session summaries are available."
)
async def _build_autopilot_system_prompt(
user: Any,
*,
start_type: ChatSessionStartType,
timezone_name: str,
target_local_date: date | None = None,
invited_user: InvitedUserRecord | None = None,
) -> str:
understanding = await understanding_db().get_business_understanding(user.id)
business_understanding = (
format_understanding_for_prompt(understanding)
if understanding
else "No saved business understanding yet."
)
recent_copilot_emails = await _get_recent_sent_email_context(user.id)
recent_session_summaries = await _get_recent_session_summary_context(user.id)
recent_manual_sessions = "Not applicable for this prompt type."
beta_application_context = "No beta application context available."
users_information_sections = [
"## Business Understanding\n" + business_understanding
]
users_information_sections.append(
"## Recent Copilot Emails Sent To User\n" + recent_copilot_emails
)
users_information_sections.append(
"## Recent Copilot Session Summaries\n" + recent_session_summaries
)
users_information = "\n\n".join(users_information_sections)
if (
start_type == ChatSessionStartType.AUTOPILOT_NIGHTLY
and target_local_date is not None
):
recent_manual_sessions = await _get_recent_manual_session_context(
user.id,
since_utc=_get_previous_local_midnight_utc(
target_local_date,
timezone_name,
),
)
tally_understanding = _get_invited_user_tally_understanding(invited_user)
if tally_understanding is not None:
beta_application_context = json.dumps(tally_understanding, ensure_ascii=False)
return await _get_system_prompt_template(
users_information,
prompt_name=_get_autopilot_prompt_name(start_type),
fallback_prompt=_get_autopilot_fallback_prompt(start_type),
template_vars={
"users_information": users_information,
"business_understanding": business_understanding,
"recent_copilot_emails": recent_copilot_emails,
"recent_session_summaries": recent_session_summaries,
"recent_manual_sessions": recent_manual_sessions,
"beta_application_context": beta_application_context,
},
)

File diff suppressed because it is too large Load Diff

View File

@@ -38,9 +38,9 @@ from backend.copilot.response_model import (
StreamToolOutputAvailable,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
_resolve_system_prompt,
client,
_get_openai_client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
@@ -89,7 +89,7 @@ async def _compress_session_messages(
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
client=_get_openai_client(),
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
@@ -160,7 +160,7 @@ async def stream_chat_completion_baseline(
session = await upsert_chat_session(session)
# Generate title for new sessions
if is_user_message and session.is_manual and not session.title:
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
@@ -177,20 +177,16 @@ async def stream_chat_completion_baseline(
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
base_system_prompt, _ = await _resolve_system_prompt(
session,
user_id,
has_conversation_history=False,
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
base_system_prompt, _ = await _resolve_system_prompt(
session,
user_id=None,
has_conversation_history=True,
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement(session)
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
@@ -203,7 +199,7 @@ async def stream_chat_completion_baseline(
if msg.role in ("user", "assistant") and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools(session)
tools = get_available_tools()
yield StreamStart(messageId=message_id, sessionId=session_id)
@@ -239,7 +235,7 @@ async def stream_chat_completion_baseline(
)
if tools:
create_kwargs["tools"] = tools
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
response = await _get_openai_client().chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""

View File

@@ -65,18 +65,6 @@ class ChatConfig(BaseSettings):
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
langfuse_autopilot_nightly_prompt_name: str = Field(
default="CoPilot Nightly",
description="Langfuse prompt name for nightly Autopilot sessions",
)
langfuse_autopilot_callback_prompt_name: str = Field(
default="CoPilot Callback",
description="Langfuse prompt name for callback Autopilot sessions",
)
langfuse_autopilot_invite_cta_prompt_name: str = Field(
default="CoPilot Beta Invite CTA",
description="Langfuse prompt name for beta invite CTA Autopilot sessions",
)
langfuse_prompt_cache_ttl: int = Field(
default=300,
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",

View File

@@ -11,6 +11,8 @@ from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
@@ -82,6 +84,17 @@ def resolve_sandbox_path(path: str) -> str:
return normalized
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped :class:`WorkspaceManager`.
Placed here (rather than in ``tools/workspace_files``) so that modules
like ``sdk/file_ref`` can import it without triggering the heavy
``tools/__init__`` import chain.
"""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.

View File

@@ -8,48 +8,19 @@ from typing import Any
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.models import ChatSessionCallbackToken as PrismaChatSessionCallbackToken
from prisma.types import (
ChatMessageCreateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from .model import ChatMessage, ChatSession, ChatSessionInfo
from .session_types import ChatSessionStartType
logger = logging.getLogger(__name__)
_UNSET = object()
class ChatSessionCallbackTokenInfo(BaseModel):
id: str
user_id: str
source_session_id: str | None = None
callback_session_message: str
expires_at: datetime
consumed_at: datetime | None = None
consumed_session_id: str | None = None
@classmethod
def from_db(
cls,
token: PrismaChatSessionCallbackToken,
) -> "ChatSessionCallbackTokenInfo":
return cls(
id=token.id,
user_id=token.userId,
source_session_id=token.sourceSessionId,
callback_session_message=token.callbackSessionMessage,
expires_at=token.expiresAt,
consumed_at=token.consumedAt,
consumed_session_id=token.consumedSessionId,
)
async def get_chat_session(session_id: str) -> ChatSession | None:
@@ -61,103 +32,9 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
return ChatSession.from_db(session) if session else None
async def has_recent_manual_message(user_id: str, since: datetime) -> bool:
message = await PrismaChatMessage.prisma().find_first(
where={
"role": "user",
"createdAt": {"gte": since},
"Session": {
"is": {
"userId": user_id,
"startType": ChatSessionStartType.MANUAL.value,
}
},
}
)
return message is not None
async def has_session_since(user_id: str, since: datetime) -> bool:
session = await PrismaChatSession.prisma().find_first(
where={"userId": user_id, "createdAt": {"gte": since}}
)
return session is not None
async def session_exists_for_execution_tag(user_id: str, execution_tag: str) -> bool:
session = await PrismaChatSession.prisma().find_first(
where={"userId": user_id, "executionTag": execution_tag}
)
return session is not None
async def get_manual_chat_sessions_since(
user_id: str,
since_utc: datetime,
limit: int,
) -> list[ChatSessionInfo]:
sessions = await PrismaChatSession.prisma().find_many(
where={
"userId": user_id,
"startType": ChatSessionStartType.MANUAL.value,
"updatedAt": {"gte": since_utc},
},
order={"updatedAt": "desc"},
take=limit,
)
return [ChatSessionInfo.from_db(session) for session in sessions]
async def get_chat_messages_since(
session_id: str,
since_utc: datetime,
) -> list[ChatMessage]:
messages = await PrismaChatMessage.prisma().find_many(
where={
"sessionId": session_id,
"createdAt": {"gte": since_utc},
},
order={"sequence": "asc"},
)
return [ChatMessage.from_db(message) for message in messages]
def _build_chat_message_create_input(
*,
session_id: str,
sequence: int,
now: datetime,
msg: dict[str, Any],
) -> ChatMessageCreateInput:
data: ChatMessageCreateInput = {
"sessionId": session_id,
"role": msg["role"],
"sequence": sequence,
"createdAt": now,
}
if msg.get("content") is not None:
data["content"] = sanitize_string(msg["content"])
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = sanitize_string(msg["refusal"])
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
return data
async def create_chat_session(
session_id: str,
user_id: str,
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
execution_tag: str | None = None,
session_config: dict[str, Any] | None = None,
) -> ChatSessionInfo:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
@@ -166,9 +43,6 @@ async def create_chat_session(
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
startType=start_type.value,
executionTag=execution_tag,
sessionConfig=SafeJson(session_config or {}),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
@@ -182,19 +56,9 @@ async def update_chat_session(
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
start_type: ChatSessionStartType | None = None,
execution_tag: str | None | object = _UNSET,
session_config: dict[str, Any] | None = None,
completion_report: dict[str, Any] | None | object = _UNSET,
completion_report_repair_count: int | None = None,
completion_report_repair_queued_at: datetime | None | object = _UNSET,
completed_at: datetime | None | object = _UNSET,
notification_email_sent_at: datetime | None | object = _UNSET,
notification_email_skipped_at: datetime | None | object = _UNSET,
) -> ChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
should_clear_completion_report = completion_report is None
if credentials is not None:
data["credentials"] = SafeJson(credentials)
@@ -208,41 +72,12 @@ async def update_chat_session(
data["totalCompletionTokens"] = total_completion_tokens
if title is not None:
data["title"] = title
if start_type is not None:
data["startType"] = start_type.value
if execution_tag is not _UNSET:
data["executionTag"] = execution_tag
if session_config is not None:
data["sessionConfig"] = SafeJson(session_config)
if completion_report is not _UNSET and completion_report is not None:
data["completionReport"] = SafeJson(completion_report)
if completion_report_repair_count is not None:
data["completionReportRepairCount"] = completion_report_repair_count
if completion_report_repair_queued_at is not _UNSET:
data["completionReportRepairQueuedAt"] = completion_report_repair_queued_at
if completed_at is not _UNSET:
data["completedAt"] = completed_at
if notification_email_sent_at is not _UNSET:
data["notificationEmailSentAt"] = notification_email_sent_at
if notification_email_skipped_at is not _UNSET:
data["notificationEmailSkippedAt"] = notification_email_skipped_at
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": {"order_by": {"sequence": "asc"}}},
)
if should_clear_completion_report:
await db.execute_raw_with_schema(
'UPDATE {schema_prefix}"ChatSession" SET "completionReport" = NULL WHERE "id" = $1',
session_id,
)
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": {"order_by": {"sequence": "asc"}}},
)
return ChatSession.from_db(session) if session else None
@@ -352,15 +187,37 @@ async def add_chat_messages_batch(
now = datetime.now(UTC)
async with db.transaction() as tx:
messages_data = [
_build_chat_message_create_input(
session_id=session_id,
sequence=start_sequence + i,
now=now,
msg=msg,
)
for i, msg in enumerate(messages)
]
# Build all message data
messages_data = []
for i, msg in enumerate(messages):
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
# Note: create_many doesn't support nested creates, use sessionId directly
data: ChatMessageCreateInput = {
"sessionId": session_id,
"role": msg["role"],
"sequence": start_sequence + i,
"createdAt": now,
}
# Add optional string fields — sanitize to strip
# PostgreSQL-incompatible control characters.
if msg.get("content") is not None:
data["content"] = sanitize_string(msg["content"])
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = sanitize_string(msg["refusal"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
messages_data.append(data)
# Run create_many and session update in parallel within transaction
# Both use the same timestamp for consistency
@@ -399,14 +256,10 @@ async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
with_auto: bool = False,
) -> list[ChatSessionInfo]:
"""Get chat sessions for a user, ordered by most recent."""
prisma_sessions = await PrismaChatSession.prisma().find_many(
where={
"userId": user_id,
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
},
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
@@ -414,88 +267,9 @@ async def get_user_chat_sessions(
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
async def get_pending_notification_chat_sessions(
limit: int = 200,
) -> list[ChatSessionInfo]:
sessions = await PrismaChatSession.prisma().find_many(
where={
"startType": {"not": ChatSessionStartType.MANUAL.value},
"notificationEmailSentAt": None,
"notificationEmailSkippedAt": None,
},
order={"updatedAt": "asc"},
take=limit,
)
return [ChatSessionInfo.from_db(session) for session in sessions]
async def get_pending_notification_chat_sessions_for_user(
user_id: str,
limit: int = 200,
) -> list[ChatSessionInfo]:
sessions = await PrismaChatSession.prisma().find_many(
where={
"userId": user_id,
"startType": {"not": ChatSessionStartType.MANUAL.value},
"notificationEmailSentAt": None,
"notificationEmailSkippedAt": None,
},
order={"updatedAt": "asc"},
take=limit,
)
return [ChatSessionInfo.from_db(session) for session in sessions]
async def get_recent_sent_email_chat_sessions(
user_id: str,
limit: int,
) -> list[ChatSessionInfo]:
sessions = await PrismaChatSession.prisma().find_many(
where={
"userId": user_id,
"startType": {"not": ChatSessionStartType.MANUAL.value},
"notificationEmailSentAt": {"not": None},
},
order={"notificationEmailSentAt": "desc"},
take=max(limit * 3, limit),
)
return [
session_info
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
if session_info.notification_email_sent_at and session_info.completion_report
][:limit]
async def get_recent_completion_report_chat_sessions(
user_id: str,
limit: int,
) -> list[ChatSessionInfo]:
sessions = await PrismaChatSession.prisma().find_many(
where={
"userId": user_id,
"startType": {"not": ChatSessionStartType.MANUAL.value},
},
order={"updatedAt": "desc"},
take=max(limit * 5, 10),
)
return [
session_info
for session_info in (ChatSessionInfo.from_db(session) for session in sessions)
if session_info.completion_report is not None
][:limit]
async def get_user_session_count(
user_id: str,
with_auto: bool = False,
) -> int:
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(
where={
"userId": user_id,
**({} if with_auto else {"startType": ChatSessionStartType.MANUAL.value}),
}
)
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
@@ -585,42 +359,3 @@ async def update_tool_message_content(
f"tool_call_id {tool_call_id}: {e}"
)
return False
async def create_chat_session_callback_token(
user_id: str,
source_session_id: str,
callback_session_message: str,
expires_at: datetime,
) -> ChatSessionCallbackTokenInfo:
token = await PrismaChatSessionCallbackToken.prisma().create(
data={
"userId": user_id,
"sourceSessionId": source_session_id,
"callbackSessionMessage": callback_session_message,
"expiresAt": expires_at,
}
)
return ChatSessionCallbackTokenInfo.from_db(token)
async def get_chat_session_callback_token(
token_id: str,
) -> ChatSessionCallbackTokenInfo | None:
token = await PrismaChatSessionCallbackToken.prisma().find_unique(
where={"id": token_id}
)
return ChatSessionCallbackTokenInfo.from_db(token) if token else None
async def mark_chat_session_callback_token_consumed(
token_id: str,
consumed_session_id: str,
) -> None:
await PrismaChatSessionCallbackToken.prisma().update(
where={"id": token_id},
data={
"consumedAt": datetime.now(UTC),
"consumedSessionId": consumed_session_id,
},
)

View File

@@ -21,7 +21,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
)
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel, Field
from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
@@ -29,11 +29,6 @@ from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from .config import ChatConfig
from .session_types import (
ChatSessionConfig,
ChatSessionStartType,
StoredCompletionReport,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
@@ -85,20 +80,11 @@ class ChatSessionInfo(BaseModel):
user_id: str
title: str | None = None
usage: list[Usage]
credentials: dict[str, dict] = Field(default_factory=dict)
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
updated_at: datetime
successful_agent_runs: dict[str, int] = Field(default_factory=dict)
successful_agent_schedules: dict[str, int] = Field(default_factory=dict)
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL
execution_tag: str | None = None
session_config: ChatSessionConfig = Field(default_factory=ChatSessionConfig)
completion_report: StoredCompletionReport | None = None
completion_report_repair_count: int = 0
completion_report_repair_queued_at: datetime | None = None
completed_at: datetime | None = None
notification_email_sent_at: datetime | None = None
notification_email_skipped_at: datetime | None = None
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
@@ -111,8 +97,6 @@ class ChatSessionInfo(BaseModel):
successful_agent_schedules = _parse_json_field(
prisma_session.successfulAgentSchedules, default={}
)
session_config = _parse_json_field(prisma_session.sessionConfig, default={})
completion_report = _parse_json_field(prisma_session.completionReport)
# Calculate usage from token counts
usage = []
@@ -126,20 +110,6 @@ class ChatSessionInfo(BaseModel):
)
)
parsed_session_config = ChatSessionConfig.model_validate(session_config or {})
parsed_completion_report = None
if isinstance(completion_report, dict):
try:
parsed_completion_report = StoredCompletionReport.model_validate(
completion_report
)
except Exception:
logger.warning(
"Invalid completionReport payload on session %s",
prisma_session.id,
exc_info=True,
)
return cls(
session_id=prisma_session.id,
user_id=prisma_session.userId,
@@ -150,15 +120,6 @@ class ChatSessionInfo(BaseModel):
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
start_type=ChatSessionStartType(str(prisma_session.startType)),
execution_tag=prisma_session.executionTag,
session_config=parsed_session_config,
completion_report=parsed_completion_report,
completion_report_repair_count=prisma_session.completionReportRepairCount,
completion_report_repair_queued_at=prisma_session.completionReportRepairQueuedAt,
completed_at=prisma_session.completedAt,
notification_email_sent_at=prisma_session.notificationEmailSentAt,
notification_email_skipped_at=prisma_session.notificationEmailSkippedAt,
)
@@ -166,13 +127,7 @@ class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(
cls,
user_id: str,
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
execution_tag: str | None = None,
session_config: ChatSessionConfig | None = None,
) -> Self:
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -182,9 +137,6 @@ class ChatSession(ChatSessionInfo):
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
start_type=start_type,
execution_tag=execution_tag,
session_config=session_config or ChatSessionConfig(),
)
@classmethod
@@ -200,16 +152,6 @@ class ChatSession(ChatSessionInfo):
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
@property
def is_manual(self) -> bool:
return self.start_type == ChatSessionStartType.MANUAL
def allows_tool(self, tool_name: str) -> bool:
return self.session_config.allows_tool(tool_name)
def disables_tool(self, tool_name: str) -> bool:
return self.session_config.disables_tool(tool_name)
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
@@ -582,9 +524,6 @@ async def _save_session_to_db(
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
start_type=session.start_type,
execution_tag=session.execution_tag,
session_config=session.session_config.model_dump(mode="json"),
)
existing_message_count = 0
@@ -600,19 +539,6 @@ async def _save_session_to_db(
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
start_type=session.start_type,
execution_tag=session.execution_tag,
session_config=session.session_config.model_dump(mode="json"),
completion_report=(
session.completion_report.model_dump(mode="json")
if session.completion_report
else None
),
completion_report_repair_count=session.completion_report_repair_count,
completion_report_repair_queued_at=session.completion_report_repair_queued_at,
completed_at=session.completed_at,
notification_email_sent_at=session.notification_email_sent_at,
notification_email_skipped_at=session.notification_email_skipped_at,
)
# Add new messages (only those after existing count)
@@ -675,13 +601,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
return session
async def create_chat_session(
user_id: str,
start_type: ChatSessionStartType = ChatSessionStartType.MANUAL,
execution_tag: str | None = None,
session_config: ChatSessionConfig | None = None,
initial_messages: list[ChatMessage] | None = None,
) -> ChatSession:
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Raises:
@@ -689,30 +609,14 @@ async def create_chat_session(
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(
user_id,
start_type=start_type,
execution_tag=execution_tag,
session_config=session_config,
)
if initial_messages:
session.messages.extend(initial_messages)
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
start_type=session.start_type,
execution_tag=session.execution_tag,
session_config=session.session_config.model_dump(mode="json"),
)
if session.messages:
await _save_session_to_db(
session,
0,
skip_existence_check=True,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")
raise DatabaseError(
@@ -732,7 +636,6 @@ async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
with_auto: bool = False,
) -> tuple[list[ChatSessionInfo], int]:
"""Get chat sessions for a user from the database with total count.
@@ -741,16 +644,8 @@ async def get_user_sessions(
number of sessions for the user (not just the current page).
"""
db = chat_db()
sessions = await db.get_user_chat_sessions(
user_id,
limit,
offset,
with_auto=with_auto,
)
total_count = await db.get_user_session_count(
user_id,
with_auto=with_auto,
)
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
return sessions, total_count

View File

@@ -19,7 +19,6 @@ from .model import (
get_chat_session,
upsert_chat_session,
)
from .session_types import ChatSessionConfig, ChatSessionStartType
messages = [
ChatMessage(content="Hello, how are you?", role="user"),
@@ -47,15 +46,7 @@ messages = [
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_serialization_deserialization():
s = ChatSession.new(
user_id="abc123",
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
execution_tag="autopilot-nightly:2026-03-13",
session_config=ChatSessionConfig(
extra_tools=["completion_report"],
disabled_tools=["edit_agent"],
),
)
s = ChatSession.new(user_id="abc123")
s.messages = messages
s.usage = [Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)]
serialized = s.model_dump_json()

View File

@@ -6,7 +6,7 @@ handling the distinction between:
- Local mode vs E2B mode (storage/filesystem differences)
"""
from backend.copilot.tools import iter_available_tools
from backend.copilot.tools import TOOL_REGISTRY
# Shared technical notes that apply to both SDK and baseline modes
_SHARED_TOOL_NOTES = """\
@@ -52,11 +52,43 @@ Examples:
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Type coercion**: The platform automatically coerces expanded string values
to match the block's expected input types. For example, if a block expects
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
an @@agptfile: expansion), the string will be parsed into the correct type.
**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.
### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
These fields accept:
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
for large files (images, videos, PDFs). The platform passes the reference
directly to the block without reading the content into memory.
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
small files only.
When a block input has `format: "file"`, **pass the `workspace://` URI
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
payloads in tool arguments and preserves binary content (images, videos)
that would be corrupted by text encoding.
Example — committing an image file to GitHub:
```json
{
"files": [{
"path": "docs/hero.png",
"content": "workspace://abc123#image/png",
"operation": "upsert"
}]
}
```
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
@@ -161,7 +193,7 @@ def _get_cloud_sandbox_supplement() -> str:
)
def _generate_tool_documentation(session=None) -> str:
def _generate_tool_documentation() -> str:
"""Auto-generate tool documentation from TOOL_REGISTRY.
NOTE: This is ONLY used in baseline mode (direct OpenAI API).
@@ -177,7 +209,11 @@ def _generate_tool_documentation(session=None) -> str:
docs = "\n## AVAILABLE TOOLS\n\n"
# Sort tools alphabetically for consistent output
for name, tool in sorted(iter_available_tools(session), key=lambda item: item[0]):
# Filter by is_available to match get_available_tools() behavior
for name in sorted(TOOL_REGISTRY.keys()):
tool = TOOL_REGISTRY[name]
if not tool.is_available:
continue
schema = tool.as_openai_tool()
desc = schema["function"].get("description", "No description available")
# Format as bullet list with tool name in code style
@@ -205,7 +241,7 @@ def get_sdk_supplement(use_e2b: bool, cwd: str = "") -> str:
return _get_local_storage_supplement(cwd)
def get_baseline_supplement(session=None) -> str:
def get_baseline_supplement() -> str:
"""Get the supplement for baseline mode (direct OpenAI API).
Baseline mode INCLUDES auto-generated tool documentation because the
@@ -215,5 +251,5 @@ def get_baseline_supplement(session=None) -> str:
Returns:
The supplement string to append to the system prompt
"""
tool_docs = _generate_tool_documentation(session)
tool_docs = _generate_tool_documentation()
return tool_docs + _SHARED_TOOL_NOTES

View File

@@ -3,12 +3,45 @@
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
circular import cycle::
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
function scope without a larger refactor (moving tool-name registration
to a separate lightweight module). The lazy-import pattern here is the
least invasive way to break the cycle while keeping module-level constants
intact.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
from typing import Any
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
# pair so new exports can be added without touching __getattr__ itself.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
}
def __getattr__(name: str) -> Any:
entry = _LAZY_IMPORTS.get(name)
if entry is not None:
module_path, attr = entry
import importlib
module = importlib.import_module(module_path, package=__name__)
value = getattr(module, attr)
globals()[name] = value
return value
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -41,12 +41,20 @@ from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
from backend.util.file_content_parser import (
BINARY_FORMATS,
MIME_TO_FORMAT,
PARSE_EXCEPTIONS,
infer_format_from_uri,
parse_file_content,
)
from backend.util.type import MediaFileType
class FileRefExpansionError(Exception):
@@ -74,6 +82,8 @@ _FILE_REF_RE = re.compile(
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
# Maximum raw byte size for bare ref structured parsing (10 MB).
_MAX_BARE_REF_BYTES = 10_000_000
@dataclass
@@ -83,6 +93,11 @@ class FileRef:
end_line: int | None # 1-indexed, inclusive
# ---------------------------------------------------------------------------
# Public API (top-down: main functions first, helpers below)
# ---------------------------------------------------------------------------
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
@@ -104,17 +119,6 @@ def parse_file_ref(text: str) -> FileRef | None:
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
@@ -130,27 +134,47 @@ async def read_file_bytes(
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
data = await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
except (PermissionError, OSError) as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
except (AttributeError, TypeError, RuntimeError) as exc:
# AttributeError/TypeError: workspace manager returned an
# unexpected type or interface; RuntimeError: async runtime issues.
logger.warning("Unexpected error reading %s: %s", plain, exc)
raise ValueError(f"Failed to read {plain}: {exc}") from exc
# NOTE: Workspace API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
# Read with a one-byte overshoot to detect files that exceed the limit
# without a separate os.path.getsize call (avoids TOCTOU race).
with open(resolved, "rb") as fh:
return fh.read()
data = fh.read(_MAX_BARE_REF_BYTES + 1)
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
f"limit {_MAX_BARE_REF_BYTES})"
)
return data
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
except OSError as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
@@ -162,9 +186,33 @@ async def read_file_bytes(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
data = bytes(await sandbox.files.read(remote, format="bytes"))
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
except Exception as exc:
# E2B SDK raises SandboxException subclasses (NotFoundException,
# TimeoutException, NotEnoughSpaceException, etc.) which don't
# inherit from standard exceptions. Import lazily to avoid a
# hard dependency on e2b at module level.
try:
from e2b.exceptions import SandboxException # noqa: PLC0415
if isinstance(exc, SandboxException):
raise ValueError(
f"Failed to read from sandbox: {plain}: {exc}"
) from exc
except ImportError:
pass
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
# so they surface as real bugs rather than being silently masked.
raise
# NOTE: E2B sandbox API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
@@ -178,15 +226,13 @@ async def resolve_file_ref(
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
raise_on_error: bool = False,
) -> str:
@@ -232,6 +278,9 @@ async def expand_file_refs_in_string(
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
# remaining == 0 means the budget was exactly exhausted by the
# previous ref. The elif below (len > remaining) won't catch
# this since 0 > 0 is false, so we need the <= 0 check.
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
@@ -252,13 +301,31 @@ async def expand_file_refs_in_string(
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
input_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
**Bare references** (the entire argument value is a single
``@@agptfile:...`` token with no surrounding text) are resolved and then
parsed according to the file's extension or MIME type. See
:mod:`backend.util.file_content_parser` for the full list of supported
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
When *input_schema* is provided and the target property has
``"type": "string"``, structured parsing is skipped — the raw file content
is returned as a plain string so blocks receive the original text.
If the format is unrecognised or parsing fails, the content is returned as
a plain string (the fallback).
**Embedded references** (``@@agptfile:`` mixed with other text) always
produce a plain string — structured parsing only applies to bare refs.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
@@ -267,15 +334,382 @@ async def expand_file_refs_in_args(
if not args:
return args
async def _expand(value: Any) -> Any:
properties = (input_schema or {}).get("properties", {})
async def _expand(
value: Any,
*,
prop_schema: dict[str, Any] | None = None,
) -> Any:
"""Recursively expand a single argument value.
Strings are checked for ``@@agptfile:`` references and expanded
(bare refs get structured parsing; embedded refs get inline
substitution). Dicts and lists are traversed recursively,
threading the corresponding sub-schema from *prop_schema* so
that nested fields also receive correct type-aware expansion.
Non-string scalars pass through unchanged.
"""
if isinstance(value, str):
ref = parse_file_ref(value)
if ref is not None:
# MediaFileType fields: return the raw URI immediately —
# no file reading, no format inference, no content parsing.
if _is_media_file_field(prop_schema):
return ref.uri
fmt = infer_format_from_uri(ref.uri)
# Workspace URIs by ID (workspace://abc123) have no extension.
# When the MIME fragment is also missing, fall back to the
# workspace file manager's metadata for format detection.
if fmt is None and ref.uri.startswith("workspace://"):
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
# Not a bare ref — do normal inline expansion.
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
# When the schema says this is an object but doesn't define
# inner properties, skip expansion — the caller (e.g.
# RunBlockTool) will expand with the actual nested schema.
if (
prop_schema is not None
and prop_schema.get("type") == "object"
and "properties" not in prop_schema
):
return value
nested_props = (prop_schema or {}).get("properties", {})
return {
k: await _expand(v, prop_schema=nested_props.get(k))
for k, v in value.items()
}
if isinstance(value, list):
return [await _expand(item) for item in value]
items_schema = (prop_schema or {}).get("items")
return [await _expand(item, prop_schema=items_schema) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
# ---------------------------------------------------------------------------
# Private helpers (used by the public functions above)
# ---------------------------------------------------------------------------
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive).
When the requested range extends beyond the file, a note is appended
so the LLM knows it received the entire remaining content.
"""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
total = len(lines)
s = (start - 1) if start is not None else 0
e = end if end is not None else total
selected = list(itertools.islice(lines, s, e))
result = "".join(selected)
if end is not None and end > total:
result += f"\n[Note: file has only {total} lines]\n"
return result
def _to_str(content: str | bytes) -> str:
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
if isinstance(content, str):
return content
return content.decode("utf-8", errors="replace")
def _check_content_size(content: str | bytes) -> None:
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
(``_expand_bare_ref``) can unify all resolution errors into a single
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
For ``bytes``, the length is the byte count directly. For ``str``,
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
mean the byte size can be up to 4x the character count.
"""
if isinstance(content, bytes):
size = len(content)
else:
char_len = len(content)
# Fast lower bound: UTF-8 byte count >= char count.
# If char count already exceeds the limit, reject immediately
# without allocating an encoded copy.
if char_len > _MAX_BARE_REF_BYTES:
size = char_len # real byte size is even larger
# Fast upper bound: each char is at most 4 UTF-8 bytes.
# If worst-case is still under the limit, skip encoding entirely.
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
return
else:
# Edge case: char count is under limit but multibyte chars
# might push byte count over. Encode to get exact size.
size = len(content.encode("utf-8"))
if size > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large for structured parsing "
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
)
async def _infer_format_from_workspace(
uri: str,
user_id: str | None,
session: ChatSession,
) -> str | None:
"""Look up workspace file metadata to infer the format.
Workspace URIs by ID (``workspace://abc123``) have no file extension.
When the MIME fragment is also absent, we query the workspace file
manager for the file's stored MIME type and original filename.
"""
if not user_id:
return None
try:
ws = parse_workspace_uri(uri)
manager = await get_workspace_manager(user_id, session.session_id)
info = await (
manager.get_file_info(ws.file_ref)
if not ws.is_path
else manager.get_file_info_by_path(ws.file_ref)
)
if info is None:
return None
# Try MIME type first, then filename extension.
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
except (
ValueError,
FileNotFoundError,
OSError,
PermissionError,
AttributeError,
TypeError,
):
# Expected failures: bad URI, missing file, permission denied, or
# workspace manager returning unexpected types. Propagate anything
# else (e.g. programming errors) so they don't get silently swallowed.
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
return None
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
if prop_schema is None:
return False
return (
prop_schema.get("type") == "string"
and prop_schema.get("format") == MediaFileType.string_format
)
async def _expand_bare_ref(
ref: FileRef,
fmt: str | None,
user_id: str | None,
session: ChatSession,
prop_schema: dict[str, Any] | None,
) -> Any:
"""Resolve and parse a bare ``@@agptfile:`` reference.
This is the structured-parsing path: the file is read, optionally parsed
according to *fmt*, and adapted to the target *prop_schema*.
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
Note: MediaFileType fields (format: "file") are handled earlier in
``_expand`` to avoid unnecessary format inference and file I/O.
"""
try:
if fmt is not None and fmt in BINARY_FORMATS:
# Binary formats need raw bytes, not UTF-8 text.
# Line ranges are meaningless for binary formats (parquet/xlsx)
# — ignore them and parse full bytes. Warn so the caller/model
# knows the range was silently dropped.
if ref.start_line is not None or ref.end_line is not None:
logger.warning(
"Line range [%s-%s] ignored for binary format %s (%s); "
"binary formats are always parsed in full.",
ref.start_line,
ref.end_line,
fmt,
ref.uri,
)
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
else:
content = await resolve_file_ref(ref, user_id, session)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# For known formats this rejects files >10 MB before parsing.
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
# but this check still guards the parsing path which has no char limit.
# _check_content_size raises ValueError, which we unify here just like
# resolution errors above.
try:
_check_content_size(content)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# When the schema declares this parameter as "string",
# return raw file content — don't parse into a structured
# type that would need json.dumps() serialisation.
expect_string = (prop_schema or {}).get("type") == "string"
if expect_string:
if isinstance(content, bytes):
raise FileRefExpansionError(
f"Cannot use {fmt} file as text input: "
f"binary formats (parquet, xlsx) must be passed "
f"to a block that accepts structured data (list/object), "
f"not a string-typed parameter."
)
return content
if fmt is not None:
# Use strict mode for binary formats so we surface the
# actual error (e.g. missing pyarrow/openpyxl, corrupt
# file) instead of silently returning garbled bytes.
strict = fmt in BINARY_FORMATS
try:
parsed = parse_file_content(content, fmt, strict=strict)
except PARSE_EXCEPTIONS as exc:
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
# Normalize bytes fallback to str so tools never
# receive raw bytes when parsing fails.
if isinstance(parsed, bytes):
parsed = _to_str(parsed)
return _adapt_to_schema(parsed, prop_schema)
# Unknown format — return as plain string, but apply
# the same per-ref character limit used by inline refs
# to prevent injecting unexpectedly large content.
text = _to_str(content)
if len(text) > _MAX_EXPAND_CHARS:
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
return text
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
"""Adapt a parsed file value to better fit the target schema type.
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
that doesn't match the block's expected type, this function converts it to
a more useful representation instead of relying on pydantic's generic
coercion (which can produce awkward results like flattened dicts → lists).
Returns *parsed* unchanged when no adaptation is needed.
"""
if prop_schema is None:
return parsed
target_type = prop_schema.get("type")
# Dict → array: delegate to helper.
if isinstance(parsed, dict) and target_type == "array":
return _adapt_dict_to_array(parsed, prop_schema)
# List → object: delegate to helper (raises for non-tabular lists).
if isinstance(parsed, list) and target_type == "object":
return _adapt_list_to_object(parsed)
# Tabular list → Any (no type): convert to list of dicts.
# Blocks like FindInDictionaryBlock have `input: Any` which produces
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
# for key lookup, but [{col: val}, ...] works with FindInDict's
# list-of-dicts branch (line 195-199 in data_manipulation.py).
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
return _tabular_to_list_of_dicts(parsed)
return parsed
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
"""Adapt a parsed dict to an array-typed field.
Extracts list-valued entries when the target item type is ``array``,
passes through unchanged when item type is ``string`` (lets pydantic error),
or wraps in ``[parsed]`` as a fallback.
"""
items_type = (prop_schema.get("items") or {}).get("type")
if items_type == "array":
# Target is List[List[Any]] — extract list-typed values from the
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
list_values = [v for v in parsed.values() if isinstance(v, list)]
if list_values:
return list_values
if items_type == "string":
# Target is List[str] — wrapping a dict would give [dict]
# which can't coerce to strings. Return unchanged and let
# pydantic surface a clear validation error.
return parsed
# Fallback: wrap in a single-element list so the block gets [dict]
# instead of pydantic flattening keys/values into a flat list.
return [parsed]
def _adapt_list_to_object(parsed: list) -> Any:
"""Adapt a parsed list to an object-typed field.
Converts tabular lists to column-dicts; raises for non-tabular lists.
"""
if _is_tabular(parsed):
return _tabular_to_column_dict(parsed)
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
# be meaningfully coerced to an object. Raise explicitly so callers
# get a clear error rather than pydantic silently wrapping the list.
raise FileRefExpansionError(
"Cannot adapt a non-tabular list to an object-typed field. "
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
)
def _is_tabular(parsed: Any) -> bool:
"""Check if parsed data is in tabular format: [[header], [row1], ...].
Uses isinstance checks because this is a structural type guard on
opaque parser output (Any), not duck typing. A Protocol wouldn't
help here — we need to verify exact list-of-lists shape.
"""
if not isinstance(parsed, list) or len(parsed) < 2:
return False
header = parsed[0]
if not isinstance(header, list) or not header:
return False
if not all(isinstance(h, str) for h in header):
return False
return all(isinstance(row, list) for row in parsed[1:])
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
Ragged rows (fewer columns than the header) get None for missing values.
Extra values beyond the header length are silently dropped.
"""
header = parsed[0]
return [
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
for row in parsed[1:]
]
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
Ragged rows (fewer columns than the header) get None for missing values,
ensuring all columns have equal length.
"""
header = parsed[0]
return {
col: [row[i] if i < len(row) else None for row in parsed[1:]]
for i, col in enumerate(header)
}

View File

@@ -175,6 +175,199 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
assert result["count"] == 42
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — bare ref structured parsing
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bare_ref_json_returns_parsed_dict():
"""Bare ref to a .json file returns parsed dict, not raw string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value", "count": 42}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == {"key": "value", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_csv_returns_parsed_table():
"""Bare ref to a .csv file returns list[list[str]] table."""
with tempfile.TemporaryDirectory() as sdk_cwd:
csv_file = os.path.join(sdk_cwd, "data.csv")
with open(csv_file, "w") as f:
f.write("Name,Score\nAlice,90\nBob,85")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"input": f"@@agptfile:{csv_file}"},
user_id="u1",
session=_make_session(),
)
assert result["input"] == [
["Name", "Score"],
["Alice", "90"],
["Bob", "85"],
]
@pytest.mark.asyncio
async def test_bare_ref_unknown_extension_returns_string():
"""Bare ref to a file with unknown extension returns plain string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
txt_file = os.path.join(sdk_cwd, "readme.txt")
with open(txt_file, "w") as f:
f.write("plain text content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{txt_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "plain text content"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_bare_ref_invalid_json_falls_back_to_string():
"""Bare ref to a .json file with invalid JSON falls back to string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "bad.json")
with open(json_file, "w") as f:
f.write("not valid json {{{")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "not valid json {{{"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_embedded_ref_always_returns_string_even_for_json():
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value"}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"prefix @@agptfile:{json_file} suffix"},
user_id="u1",
session=_make_session(),
)
assert isinstance(result["data"], str)
assert result["data"].startswith("prefix ")
assert result["data"].endswith(" suffix")
@pytest.mark.asyncio
async def test_bare_ref_yaml_returns_parsed_dict():
"""Bare ref to a .yaml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
yaml_file = os.path.join(sdk_cwd, "config.yaml")
with open(yaml_file, "w") as f:
f.write("name: test\ncount: 42\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{yaml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_binary_with_line_range_ignores_range():
"""Bare ref to a binary file (.parquet) with line range parses the full file.
Binary formats (parquet, xlsx) ignore line ranges — the full content is
parsed and the range is silently dropped with a log warning.
"""
try:
import pandas as pd
except ImportError:
pytest.skip("pandas not installed")
try:
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
except ImportError:
pytest.skip("pyarrow not installed")
with tempfile.TemporaryDirectory() as sdk_cwd:
parquet_file = os.path.join(sdk_cwd, "data.parquet")
import io as _io
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
buf = _io.BytesIO()
df.to_parquet(buf, index=False)
with open(parquet_file, "wb") as f:
f.write(buf.getvalue())
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
# Line range [1-2] should be silently ignored for binary formats.
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{parquet_file}[1-2]"},
user_id="u1",
session=_make_session(),
)
# Full file is returned despite the line range.
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
@pytest.mark.asyncio
async def test_bare_ref_toml_returns_parsed_dict():
"""Bare ref to a .toml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
toml_file = os.path.join(sdk_cwd, "config.toml")
with open(toml_file, "w") as f:
f.write('name = "test"\ncount = 42\n')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{toml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@@ -219,7 +412,7 @@ async def test_read_file_handler_workspace_uri():
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
@@ -276,7 +469,7 @@ async def test_read_file_bytes_workspace_virtual_path():
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,7 @@ import subprocess
import sys
import uuid
from collections.abc import AsyncGenerator
from typing import Any, Protocol, cast
from typing import Any, cast
import openai
from claude_agent_sdk import (
@@ -29,6 +29,7 @@ from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -56,13 +57,12 @@ from ..response_model import (
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_generate_session_title,
_is_langfuse_configured,
_resolve_system_prompt,
)
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
@@ -88,10 +88,6 @@ logger = logging.getLogger(__name__)
config = ChatConfig()
class _ClaudeSDKTransport(Protocol):
async def write(self, data: str) -> None: ...
def _setup_langfuse_otel() -> None:
"""Configure OTEL tracing for the Claude Agent SDK → Langfuse.
@@ -141,16 +137,6 @@ def _setup_langfuse_otel() -> None:
_setup_langfuse_otel()
async def _write_multimodal_query(
client: ClaudeSDKClient,
user_message: dict[str, Any],
) -> None:
transport = cast(_ClaudeSDKTransport | None, getattr(client, "_transport", None))
if transport is None:
raise RuntimeError("Claude SDK transport is unavailable for multimodal input")
await transport.write(json.dumps(user_message) + "\n")
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@@ -579,7 +565,7 @@ async def _prepare_file_attachments(
return empty
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
except Exception:
logger.warning(
"Failed to create workspace manager for file attachments",
@@ -704,7 +690,7 @@ async def stream_chat_completion_sdk(
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and session.is_manual and not session.title:
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
@@ -819,11 +805,7 @@ async def stream_chat_completion_sdk(
e2b_sandbox, (base_system_prompt, _), dl = await asyncio.gather(
_setup_e2b(),
_resolve_system_prompt(
session,
user_id,
has_conversation_history=has_history,
),
_build_system_prompt(user_id, has_conversation_history=has_history),
_fetch_transcript(),
)
@@ -880,7 +862,7 @@ async def stream_chat_completion_sdk(
"Claude Code CLI subscription (requires `claude login`)."
)
mcp_server = create_copilot_mcp_server(session, use_e2b=use_e2b)
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
sdk_model = _resolve_sdk_model()
@@ -894,7 +876,7 @@ async def stream_chat_completion_sdk(
on_compact=compaction.on_compact,
)
allowed = get_copilot_tool_names(session, use_e2b=use_e2b)
allowed = get_copilot_tool_names(use_e2b=use_e2b)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
@@ -995,7 +977,10 @@ async def stream_chat_completion_sdk(
"parent_tool_use_id": None,
"session_id": session_id,
}
await _write_multimodal_query(client, user_msg)
assert client._transport is not None # noqa: SLF001
await client._transport.write( # noqa: SLF001
json.dumps(user_msg) + "\n"
)
# Capture user message in transcript (multimodal)
transcript_builder.append_user(content=content_blocks)
else:

View File

@@ -20,7 +20,7 @@ class _FakeFileInfo:
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
class TestPrepareFileAttachments:
@@ -205,29 +205,6 @@ class TestPromptSupplement:
):
assert "`browser_navigate`" in docs
def test_baseline_supplement_respects_session_disabled_tools(self):
"""Session-specific docs should hide disabled tools and include added session tools."""
from backend.copilot.model import ChatSession
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.session_types import (
ChatSessionConfig,
ChatSessionStartType,
)
session = ChatSession.new(
"user-1",
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
session_config=ChatSessionConfig(
extra_tools=["completion_report"],
disabled_tools=["edit_agent"],
),
)
docs = get_baseline_supplement(session)
assert "`completion_report`" in docs
assert "`edit_agent`" not in docs
def test_baseline_supplement_includes_workflows(self):
"""Baseline supplement should include workflow guidance in tool descriptions."""
from backend.copilot.prompting import get_baseline_supplement
@@ -242,13 +219,15 @@ class TestPromptSupplement:
def test_baseline_supplement_completeness(self):
"""All available tools from TOOL_REGISTRY should appear in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import iter_available_tools
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Verify each available registered tool is documented
# (matches _generate_tool_documentation which filters with iter_available_tools)
for tool_name, _ in iter_available_tools():
# (matches _generate_tool_documentation which filters by is_available)
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
assert (
f"`{tool_name}`" in docs
), f"Tool '{tool_name}' missing from baseline supplement"
@@ -298,12 +277,14 @@ class TestPromptSupplement:
def test_baseline_supplement_no_duplicate_tools(self):
"""No tool should appear multiple times in baseline supplement."""
from backend.copilot.prompting import get_baseline_supplement
from backend.copilot.tools import iter_available_tools
from backend.copilot.tools import TOOL_REGISTRY
docs = get_baseline_supplement()
# Count occurrences of each available tool in the entire supplement
for tool_name, _ in iter_available_tools():
for tool_name, tool in TOOL_REGISTRY.items():
if not tool.is_available:
continue
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"

View File

@@ -32,7 +32,7 @@ from backend.copilot.sdk.file_ref import (
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import iter_available_tools
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -338,11 +338,7 @@ def _text_from_mcp_result(result: dict[str, Any]) -> str:
)
def create_copilot_mcp_server(
session: ChatSession,
*,
use_e2b: bool = False,
):
def create_copilot_mcp_server(*, use_e2b: bool = False):
"""Create an in-process MCP server configuration for CoPilot tools.
When *use_e2b* is True, five additional MCP file tools are registered
@@ -351,7 +347,7 @@ def create_copilot_mcp_server(
:func:`get_sdk_disallowed_tools`.
"""
def _truncating(fn, tool_name: str):
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
@@ -365,7 +361,9 @@ def create_copilot_mcp_server(
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
args = await expand_file_refs_in_args(
args, user_id, session, input_schema=input_schema
)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
@@ -391,13 +389,14 @@ def create_copilot_mcp_server(
sdk_tools = []
for tool_name, base_tool in iter_available_tools(session):
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
schema = _build_input_schema(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
schema,
)(_truncating(handler, tool_name, input_schema=schema))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
@@ -479,30 +478,25 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# Static tool name list for the non-E2B case (backward compatibility).
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
def get_copilot_tool_names(
session: ChatSession,
*,
use_e2b: bool = False,
) -> list[str]:
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
equivalents that route to the E2B sandbox.
"""
tool_names = [
f"{MCP_TOOL_PREFIX}{name}" for name, _ in iter_available_tools(session)
]
if not use_e2b:
return [
*tool_names,
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
return list(COPILOT_TOOL_NAMES)
return [
*tool_names,
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
*_SDK_BUILTIN_ALWAYS,

View File

@@ -3,14 +3,11 @@
import pytest
from backend.copilot.context import get_sdk_cwd
from backend.copilot.model import ChatSession
from backend.copilot.session_types import ChatSessionConfig, ChatSessionStartType
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_copilot_tool_names,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,
@@ -171,20 +168,3 @@ class TestTruncationAndStashIntegration:
text = _text_from_mcp_result(truncated)
assert len(text) < len(big_text)
assert len(str(truncated)) <= _MCP_MAX_CHARS
class TestSessionToolFiltering:
def test_disabled_tools_are_removed_from_sdk_allowed_tools(self):
session = ChatSession.new(
"user-1",
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
session_config=ChatSessionConfig(
extra_tools=["completion_report"],
disabled_tools=["edit_agent"],
),
)
tool_names = get_copilot_tool_names(session)
assert "mcp__copilot__completion_report" in tool_names
assert "mcp__copilot__edit_agent" not in tool_names

View File

@@ -22,16 +22,30 @@ from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.settings import AppEnvironment, Settings
from .config import ChatConfig
from .model import ChatSession, ChatSessionInfo, get_chat_session, upsert_chat_session
from .model import ChatSessionInfo, get_chat_session, upsert_chat_session
logger = logging.getLogger(__name__)
config = ChatConfig()
settings = Settings()
client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
_client: LangfuseAsyncOpenAI | None = None
_langfuse = None
langfuse = get_client()
def _get_openai_client() -> LangfuseAsyncOpenAI:
global _client
if _client is None:
_client = LangfuseAsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
return _client
def _get_langfuse():
global _langfuse
if _langfuse is None:
_langfuse = get_client()
return _langfuse
# Default system prompt used when Langfuse is not configured
# Provides minimal baseline tone and personality - all workflow, tools, and
@@ -64,13 +78,7 @@ def _is_langfuse_configured() -> bool:
)
async def _get_system_prompt_template(
context: str,
*,
prompt_name: str | None = None,
fallback_prompt: str | None = None,
template_vars: dict[str, str] | None = None,
) -> str:
async def _get_system_prompt_template(context: str) -> str:
"""Get the system prompt, trying Langfuse first with fallback to default.
Args:
@@ -79,11 +87,6 @@ async def _get_system_prompt_template(
Returns:
The compiled system prompt string.
"""
resolved_prompt_name = prompt_name or config.langfuse_prompt_name
resolved_template_vars = {
"users_information": context,
**(template_vars or {}),
}
if _is_langfuse_configured():
try:
# Use asyncio.to_thread to avoid blocking the event loop
@@ -95,17 +98,17 @@ async def _get_system_prompt_template(
else "latest"
)
prompt = await asyncio.to_thread(
langfuse.get_prompt,
resolved_prompt_name,
_get_langfuse().get_prompt,
config.langfuse_prompt_name,
label=label,
cache_ttl_seconds=config.langfuse_prompt_cache_ttl,
)
return prompt.compile(**resolved_template_vars)
return prompt.compile(users_information=context)
except Exception as e:
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
# Fallback to default prompt
return (fallback_prompt or DEFAULT_SYSTEM_PROMPT).format(**resolved_template_vars)
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(
@@ -142,21 +145,6 @@ async def _build_system_prompt(
return compiled, understanding
async def _resolve_system_prompt(
session: ChatSession,
user_id: str | None,
*,
has_conversation_history: bool = False,
) -> tuple[str, Any]:
override = session.session_config.system_prompt_override
if override:
return override, None
return await _build_system_prompt(
user_id,
has_conversation_history=has_conversation_history,
)
async def _generate_session_title(
message: str,
user_id: str | None = None,
@@ -184,7 +172,7 @@ async def _generate_session_title(
"environment": settings.config.app_env.value,
}
response = await client.chat.completions.create(
response = await _get_openai_client().chat.completions.create(
model=config.title_model,
messages=[
{

View File

@@ -1,60 +0,0 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, model_validator
class ChatSessionStartType(str, Enum):
MANUAL = "MANUAL"
AUTOPILOT_NIGHTLY = "AUTOPILOT_NIGHTLY"
AUTOPILOT_CALLBACK = "AUTOPILOT_CALLBACK"
AUTOPILOT_INVITE_CTA = "AUTOPILOT_INVITE_CTA"
class ChatSessionConfig(BaseModel):
system_prompt_override: str | None = None
initial_user_message: str | None = None
initial_assistant_message: str | None = None
extra_tools: list[str] = Field(default_factory=list)
disabled_tools: list[str] = Field(default_factory=list)
def allows_tool(self, tool_name: str) -> bool:
return tool_name in self.extra_tools
def disables_tool(self, tool_name: str) -> bool:
return tool_name in self.disabled_tools
class CompletionReportInput(BaseModel):
thoughts: str
should_notify_user: bool
email_title: str | None = None
email_body: str | None = None
callback_session_message: str | None = None
approval_summary: str | None = None
@model_validator(mode="after")
def validate_notification_fields(self) -> "CompletionReportInput":
if self.should_notify_user:
required_fields = {
"email_title": self.email_title,
"email_body": self.email_body,
"callback_session_message": self.callback_session_message,
}
missing = [
field_name for field_name, value in required_fields.items() if not value
]
if missing:
raise ValueError(
"Missing required notification fields: " + ", ".join(missing)
)
return self
class StoredCompletionReport(CompletionReportInput):
has_pending_approvals: bool
pending_approval_count: int
pending_approval_graph_exec_id: str | None = None
saved_at: datetime

View File

@@ -17,12 +17,11 @@ Subscribers:
import asyncio
import logging
import time
from collections.abc import Awaitable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal, cast
from typing import Any, Literal
import orjson
from pydantic import BaseModel, ConfigDict, Field
from backend.api.model import CopilotCompletionPayload
from backend.data.notification_bus import (
@@ -56,12 +55,6 @@ _listener_sessions: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
SESSION_LOOKUP_RETRY_SECONDS = 0.05
STREAM_REPLAY_COUNT = 1000
STREAM_XREAD_BLOCK_MS = 5000
STREAM_XREAD_COUNT = 100
STALE_SESSION_BUFFER_SECONDS = 300
UNSUBSCRIBE_TIMEOUT_SECONDS = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
@@ -75,24 +68,19 @@ return 0
"""
SessionStatus = Literal["running", "completed", "failed"]
RedisHash = dict[str, str]
RedisStreamMessages = list[tuple[str, list[tuple[str, RedisHash]]]]
class ActiveSession(BaseModel):
@dataclass
class ActiveSession:
"""Represents an active streaming session (metadata only, no in-memory queues)."""
model_config = ConfigDict(frozen=True)
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
turn_id: str = ""
blocking: bool = False # If True, HTTP request is waiting for completion
status: SessionStatus = "running"
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_session_meta_key(session_id: str) -> str:
@@ -105,54 +93,7 @@ def _get_turn_stream_key(turn_id: str) -> str:
return f"{config.turn_stream_prefix}{turn_id}"
async def _redis_hset_mapping(redis: Any, key: str, mapping: RedisHash) -> int:
return await cast(Awaitable[int], redis.hset(key, mapping=mapping))
async def _redis_hgetall(redis: Any, key: str) -> RedisHash:
return cast(
RedisHash,
await cast(Awaitable[dict[str, str]], redis.hgetall(key)),
)
async def _redis_hget(redis: Any, key: str, field: str) -> str | None:
return cast(
str | None,
await cast(Awaitable[str | None], redis.hget(key, field)),
)
async def _redis_xread(
redis: Any,
streams: dict[str, str],
*,
count: int,
block: int | None,
) -> RedisStreamMessages:
return cast(
RedisStreamMessages,
await cast(
Awaitable[RedisStreamMessages],
redis.xread(streams, count=count, block=block),
),
)
async def _redis_complete_session(
redis: Any,
meta_key: str,
status: SessionStatus,
) -> int:
return int(
await cast(
Awaitable[int | str],
redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status),
)
)
def _parse_session_meta(meta: RedisHash, session_id: str = "") -> ActiveSession:
def _parse_session_meta(meta: dict[Any, Any], session_id: str = "") -> ActiveSession:
"""Parse a raw Redis hash into a typed ActiveSession.
Centralises the ``meta.get(...)`` boilerplate so callers don't repeat it.
@@ -166,7 +107,7 @@ def _parse_session_meta(meta: RedisHash, session_id: str = "") -> ActiveSession:
tool_name=meta.get("tool_name", ""),
turn_id=meta.get("turn_id", "") or session_id,
blocking=meta.get("blocking") == "1",
status=cast(SessionStatus, meta.get("status", "running")),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
@@ -229,8 +170,7 @@ async def create_session(
# No need to delete old stream — each turn_id is a fresh UUID
hset_start = time.perf_counter()
await _redis_hset_mapping(
redis,
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"session_id": session_id,
@@ -340,108 +280,6 @@ async def publish_chunk(
return message_id
def _decode_stream_chunk(msg_data: RedisHash) -> StreamBaseResponse | None:
raw_data = msg_data.get("data")
if raw_data is None:
return None
chunk_data = orjson.loads(raw_data)
return _reconstruct_chunk(chunk_data)
async def _replay_messages(
messages: RedisStreamMessages,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
*,
last_message_id: str,
) -> tuple[int, str]:
replayed_count = 0
replay_last_id = last_message_id
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id
try:
chunk = _decode_stream_chunk(msg_data)
if chunk is None:
continue
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as exc:
logger.warning("Failed to replay message: %s", exc)
return replayed_count, replay_last_id
async def _deliver_message_to_queue(
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
chunk: StreamBaseResponse,
*,
last_delivered_id: str,
log_meta: dict[str, Any],
) -> bool:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
return True
except asyncio.TimeoutError:
logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
)
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
logger.error(
f"Cannot deliver overflow error for session {session_id}, queue completely blocked"
)
return False
async def _handle_xread_timeout(
redis: Any,
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> bool:
meta_key = _get_session_meta_key(session_id)
status = await _redis_hget(redis, meta_key, "status")
if status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(f"Timeout delivering finish event for session {session_id}")
return False
try:
await asyncio.wait_for(
subscriber_queue.put(StreamHeartbeat()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(f"Timeout delivering heartbeat for session {session_id}")
return True
async def subscribe_to_session(
session_id: str,
user_id: str | None,
@@ -475,7 +313,7 @@ async def subscribe_to_session(
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta = await _redis_hgetall(redis, meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
@@ -490,8 +328,8 @@ async def subscribe_to_session(
"[TIMING] Session not found on first attempt, retrying after 50ms delay",
extra={"json_fields": {**log_meta}},
)
await asyncio.sleep(SESSION_LOOKUP_RETRY_SECONDS)
meta = await _redis_hgetall(redis, meta_key)
await asyncio.sleep(0.05) # 50ms
meta = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
@@ -536,12 +374,7 @@ async def subscribe_to_session(
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await _redis_xread(
redis,
{stream_key: last_message_id},
block=None,
count=STREAM_REPLAY_COUNT,
)
messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
@@ -554,11 +387,22 @@ async def subscribe_to_session(
},
)
replayed_count, replay_last_id = await _replay_messages(
messages,
subscriber_queue,
last_message_id=last_message_id,
)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
@@ -611,7 +455,7 @@ async def _stream_listener(
session_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict[str, Any] | None = None,
log_meta: dict | None = None,
turn_id: str = "",
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -655,11 +499,8 @@ async def _stream_listener(
# Short timeout prevents frontend timeout (12s) while waiting for heartbeats (15s)
xread_start = time.perf_counter()
xread_count += 1
messages = await _redis_xread(
redis,
{stream_key: current_id},
block=STREAM_XREAD_BLOCK_MS,
count=STREAM_XREAD_COUNT,
messages = await redis.xread(
{stream_key: current_id}, block=5000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
@@ -691,66 +532,114 @@ async def _stream_listener(
)
if not messages:
if not await _handle_xread_timeout(
redis,
session_id,
subscriber_queue,
):
# Timeout - check if session is still running
meta_key = _get_session_meta_key(session_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
# Stop if session metadata is gone (TTL expired) or status is not "running"
if status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for session {session_id}"
)
break
# Session still running - send heartbeat to keep connection alive
# This prevents frontend timeout (12s) during long-running operations
try:
await asyncio.wait_for(
subscriber_queue.put(StreamHeartbeat()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering heartbeat for session {session_id}"
)
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk = _decode_stream_chunk(msg_data)
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError:
logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for session {session_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except Exception as e:
logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
continue
if chunk is None:
continue
delivered = await _deliver_message_to_queue(
session_id,
subscriber_queue,
chunk,
last_delivered_id=last_delivered_id,
log_meta=log_meta,
)
if delivered:
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time / 1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000
@@ -823,16 +712,16 @@ async def mark_session_completed(
Returns:
True if session was newly marked completed, False if already completed/failed
"""
status: SessionStatus = "failed" if error_message else "completed"
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
# Resolve turn_id for publishing to the correct stream
meta = await _redis_hgetall(redis, meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
# Atomic compare-and-swap: only update if status is "running"
result = await _redis_complete_session(redis, meta_key, status)
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Session {session_id} already completed/failed, skipping")
@@ -885,18 +774,6 @@ async def mark_session_completed(
f"for session {session_id}: {e}"
)
try:
from backend.copilot.autopilot import handle_non_manual_session_completion
await handle_non_manual_session_completion(session_id)
except Exception as e:
logger.warning(
"Failed to process non-manual completion for session %s: %s",
session_id,
e,
exc_info=True,
)
return True
@@ -911,7 +788,7 @@ async def get_session(session_id: str) -> ActiveSession | None:
"""
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta = await _redis_hgetall(redis, meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
@@ -938,7 +815,7 @@ async def get_session_with_expiry_info(
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta = await _redis_hgetall(redis, meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Metadata expired — we can't resolve turn_id, so check using
@@ -970,7 +847,7 @@ async def get_active_session(
redis = await get_redis_async()
meta_key = _get_session_meta_key(session_id)
meta = await _redis_hgetall(redis, meta_key)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None, "0-0"
@@ -994,9 +871,7 @@ async def get_active_session(
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (datetime.now(timezone.utc) - created_at).total_seconds()
stale_threshold = (
COPILOT_CONSUMER_TIMEOUT_SECONDS + STALE_SESSION_BUFFER_SECONDS
)
stale_threshold = COPILOT_CONSUMER_TIMEOUT_SECONDS + 300 # + 5min buffer
if age_seconds > stale_threshold:
logger.warning(
f"[STALE_SESSION] Auto-completing stale session {session_id[:8]}... "
@@ -1071,11 +946,7 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
}
chunk_type = chunk_data.get("type")
if not isinstance(chunk_type, str):
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
chunk_class = type_to_class.get(chunk_type)
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
if chunk_class is None:
logger.warning(f"Unknown chunk type: {chunk_type}")
@@ -1140,7 +1011,7 @@ async def unsubscribe_from_session(
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(listener_task, timeout=UNSUBSCRIBE_TIMEOUT_SECONDS)
await asyncio.wait_for(listener_task, timeout=5.0)
except asyncio.CancelledError:
# Expected - the task was successfully cancelled
pass

View File

@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .completion_report import CompletionReportTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
@@ -51,12 +50,10 @@ if TYPE_CHECKING:
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
SESSION_SCOPED_TOOL_NAMES = {"completion_report"}
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"completion_report": CompletionReportTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
@@ -106,38 +103,16 @@ find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
def is_tool_enabled(tool_name: str, session: "ChatSession | None" = None) -> bool:
if tool_name not in TOOL_REGISTRY:
return False
if session is not None and session.disables_tool(tool_name):
return False
if tool_name not in SESSION_SCOPED_TOOL_NAMES:
return True
if session is None:
return False
return session.allows_tool(tool_name)
def iter_available_tools(
session: "ChatSession | None" = None,
) -> list[tuple[str, BaseTool]]:
return [
(tool_name, tool)
for tool_name, tool in TOOL_REGISTRY.items()
if tool.is_available and is_tool_enabled(tool_name, session)
]
def get_available_tools(
session: "ChatSession | None" = None,
) -> list[ChatCompletionToolParam]:
def get_available_tools() -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
"""
return [tool.as_openai_tool() for _, tool in iter_available_tools(session)]
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
]
def get_tool(tool_name: str) -> BaseTool | None:
@@ -153,9 +128,6 @@ async def execute_tool(
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
if not is_tool_enabled(tool_name, session):
raise ValueError(f"Tool {tool_name} is not enabled for this session")
tool = get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")

View File

@@ -32,6 +32,7 @@ import shutil
import tempfile
from typing import Any
from backend.copilot.context import get_workspace_manager
from backend.copilot.model import ChatSession
from backend.util.request import validate_url_host
@@ -43,7 +44,6 @@ from .models import (
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
@@ -194,7 +194,7 @@ async def _save_browser_state(
),
}
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
@@ -218,7 +218,7 @@ async def _restore_browser_state(
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_manager(user_id, session_name)
manager = await get_workspace_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)

View File

@@ -897,7 +897,7 @@ class TestHasLocalSession:
# _save_browser_state
# ---------------------------------------------------------------------------
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
def _make_mock_manager():

View File

@@ -1,83 +0,0 @@
"""Tool for finalizing non-manual Copilot sessions."""
from typing import Any
from backend.copilot.constants import COPILOT_SESSION_PREFIX
from backend.copilot.model import ChatSession
from backend.copilot.session_types import CompletionReportInput
from backend.data.db_accessors import review_db
from .base import BaseTool
from .models import CompletionReportSavedResponse, ErrorResponse, ToolResponseBase
class CompletionReportTool(BaseTool):
@property
def name(self) -> str:
return "completion_report"
@property
def description(self) -> str:
return (
"Finalize a non-manual session after you have finished the work. "
"Use this exactly once at the end of the flow. "
"Summarize what you did, state whether the user should be notified, "
"and provide any email/callback content that should be used."
)
@property
def parameters(self) -> dict[str, Any]:
schema = CompletionReportInput.model_json_schema()
return {
"type": "object",
"properties": schema.get("properties", {}),
"required": [
"thoughts",
"should_notify_user",
"email_title",
"email_body",
"callback_session_message",
"approval_summary",
],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
if session.is_manual:
return ErrorResponse(
message="completion_report is only available in non-manual sessions.",
session_id=session.session_id,
)
try:
report = CompletionReportInput.model_validate(kwargs)
except Exception as exc:
return ErrorResponse(
message="completion_report arguments are invalid.",
error=str(exc),
session_id=session.session_id,
)
pending_approval_count = await review_db().count_pending_reviews_for_graph_exec(
f"{COPILOT_SESSION_PREFIX}{session.session_id}",
session.user_id,
)
if pending_approval_count > 0 and not report.approval_summary:
return ErrorResponse(
message=(
"approval_summary is required because this session has pending approvals."
),
session_id=session.session_id,
)
return CompletionReportSavedResponse(
message="Completion report recorded successfully.",
session_id=session.session_id,
has_pending_approvals=pending_approval_count > 0,
pending_approval_count=pending_approval_count,
)

View File

@@ -1,95 +0,0 @@
from typing import cast
from unittest.mock import AsyncMock, Mock
import pytest
from backend.copilot.model import ChatSession
from backend.copilot.session_types import ChatSessionStartType
from backend.copilot.tools.completion_report import CompletionReportTool
from backend.copilot.tools.models import CompletionReportSavedResponse, ResponseType
@pytest.mark.asyncio
async def test_completion_report_rejects_manual_sessions() -> None:
tool = CompletionReportTool()
session = ChatSession.new("user-1")
response = await tool._execute(
user_id="user-1",
session=session,
thoughts="Wrapped up the session.",
should_notify_user=False,
email_title=None,
email_body=None,
callback_session_message=None,
approval_summary=None,
)
assert response.type == ResponseType.ERROR
assert "non-manual sessions" in response.message
@pytest.mark.asyncio
async def test_completion_report_requires_approval_summary_when_pending(
mocker,
) -> None:
tool = CompletionReportTool()
session = ChatSession.new(
"user-1",
start_type=ChatSessionStartType.AUTOPILOT_NIGHTLY,
)
review_store = Mock()
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=2)
mocker.patch(
"backend.copilot.tools.completion_report.review_db",
return_value=review_store,
)
response = await tool._execute(
user_id="user-1",
session=session,
thoughts="Prepared a recommendation for the user.",
should_notify_user=True,
email_title="Your nightly update",
email_body="I found something worth reviewing.",
callback_session_message="Let's review the next step together.",
approval_summary=None,
)
assert response.type == ResponseType.ERROR
assert "approval_summary is required" in response.message
@pytest.mark.asyncio
async def test_completion_report_succeeds_without_pending_approvals(
mocker,
) -> None:
tool = CompletionReportTool()
session = ChatSession.new(
"user-1",
start_type=ChatSessionStartType.AUTOPILOT_CALLBACK,
)
review_store = Mock()
review_store.count_pending_reviews_for_graph_exec = AsyncMock(return_value=0)
mocker.patch(
"backend.copilot.tools.completion_report.review_db",
return_value=review_store,
)
response = await tool._execute(
user_id="user-1",
session=session,
thoughts="Reviewed the account and prepared a useful follow-up.",
should_notify_user=True,
email_title="Autopilot found something useful",
email_body="I put together a recommendation for you.",
callback_session_message="Open this chat and I will walk you through it.",
approval_summary=None,
)
assert response.type == ResponseType.COMPLETION_REPORT_SAVED
response = cast(CompletionReportSavedResponse, response)
assert response.has_pending_approvals is False
assert response.pending_approval_count == 0

View File

@@ -16,7 +16,6 @@ class ResponseType(str, Enum):
ERROR = "error"
NO_RESULTS = "no_results"
NEED_LOGIN = "need_login"
COMPLETION_REPORT_SAVED = "completion_report_saved"
# Agent discovery & execution
AGENTS_FOUND = "agents_found"
@@ -139,7 +138,7 @@ class NoResultsResponse(ToolResponseBase):
"""Response when no agents found."""
type: ResponseType = ResponseType.NO_RESULTS
suggestions: list[str] = Field(default_factory=list)
suggestions: list[str] = []
name: str = "no_results"
@@ -171,8 +170,8 @@ class AgentDetails(BaseModel):
name: str
description: str
in_library: bool = False
inputs: dict[str, Any] = Field(default_factory=dict)
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
inputs: dict[str, Any] = {}
credentials: list[CredentialsMetaInput] = []
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
trigger_info: dict[str, Any] | None = None
@@ -192,7 +191,7 @@ class UserReadiness(BaseModel):
"""User readiness status."""
has_all_credentials: bool = False
missing_credentials: dict[str, Any] = Field(default_factory=dict)
missing_credentials: dict[str, Any] = {}
ready_to_run: bool = False
@@ -249,14 +248,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class CompletionReportSavedResponse(ToolResponseBase):
"""Response for completion_report."""
type: ResponseType = ResponseType.COMPLETION_REPORT_SAVED
has_pending_approvals: bool = False
pending_approval_count: int = 0
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
@@ -445,9 +436,9 @@ class BlockDetails(BaseModel):
id: str
name: str
description: str
inputs: dict[str, Any] = Field(default_factory=dict)
outputs: dict[str, Any] = Field(default_factory=dict)
credentials: list[CredentialsMetaInput] = Field(default_factory=list)
inputs: dict[str, Any] = {}
outputs: dict[str, Any] = {}
credentials: list[CredentialsMetaInput] = []
class BlockDetailsResponse(ToolResponseBase):
@@ -640,7 +631,7 @@ class FolderInfo(BaseModel):
class FolderTreeInfo(FolderInfo):
"""Folder with nested children for tree display."""
children: list["FolderTreeInfo"] = Field(default_factory=list)
children: list["FolderTreeInfo"] = []
class FolderCreatedResponse(ToolResponseBase):
@@ -687,6 +678,6 @@ class AgentsMovedToFolderResponse(ToolResponseBase):
type: ResponseType = ResponseType.AGENTS_MOVED_TO_FOLDER
agent_ids: list[str]
agent_names: list[str] = Field(default_factory=list)
agent_names: list[str] = []
folder_id: str | None = None
count: int = 0

View File

@@ -12,6 +12,7 @@ from backend.copilot.constants import (
COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
from backend.data.db_accessors import review_db
from backend.data.execution import ExecutionContext
@@ -197,6 +198,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Expand @@agptfile: refs in input_data with the block's input
# schema. The generic _truncating wrapper skips opaque object
# properties (input_data has no declared inner properties in the
# tool schema), so file ref tokens are still intact here.
# Using the block's schema lets us return raw text for string-typed
# fields and parsed structures for list/dict-typed fields.
if input_data:
try:
input_data = await expand_file_refs_in_args(
input_data,
user_id,
session,
input_schema=input_schema,
)
except FileRefExpansionError as exc:
return ErrorResponse(
message=(
f"Failed to resolve file reference: {exc}. "
"Ensure the file exists before referencing it."
),
session_id=session_id,
)
if missing_credentials:
# Return setup requirements response with missing credentials
credentials_fields_info = block.input_schema.get_credentials_fields_info()

View File

@@ -10,11 +10,11 @@ from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_workspace_manager,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -218,12 +218,6 @@ def _is_text_mime(mime_type: str) -> bool:
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped WorkspaceManager."""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
async def _resolve_file(
manager: WorkspaceManager,
file_id: str | None,
@@ -386,7 +380,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
files = await manager.list_files(
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
)
@@ -536,7 +530,7 @@ class ReadWorkspaceFileTool(BaseTool):
)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved
@@ -772,7 +766,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
await scan_content_safe(content, filename=filename)
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content,
filename=filename,
@@ -899,7 +893,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved

View File

@@ -92,19 +92,6 @@ def user_db():
return user_db
def invited_user_db():
if db.is_connected():
from backend.data import invited_user as _invited_user_db
invited_user_db = _invited_user_db
else:
from backend.util.clients import get_database_manager_async_client
invited_user_db = get_database_manager_async_client()
return invited_user_db
def understanding_db():
if db.is_connected():
from backend.data import understanding as _understanding_db

View File

@@ -79,7 +79,6 @@ from backend.data.graph import (
from backend.data.human_review import (
cancel_pending_reviews_for_execution,
check_approval,
count_pending_reviews_for_graph_exec,
delete_review_by_node_exec_id,
get_or_create_human_review,
get_pending_reviews_for_execution,
@@ -87,7 +86,6 @@ from backend.data.human_review import (
has_pending_reviews_for_graph_exec,
update_review_processed_status,
)
from backend.data.invited_user import list_invited_users_for_auth_users
from backend.data.notifications import (
clear_all_user_notification_batches,
create_or_add_to_user_notification_batch,
@@ -109,7 +107,6 @@ from backend.data.user import (
get_user_email_verification,
get_user_integrations,
get_user_notification_preference,
list_users,
update_user_integrations,
)
from backend.data.workspace import (
@@ -118,7 +115,6 @@ from backend.data.workspace import (
get_or_create_workspace,
get_workspace_file,
get_workspace_file_by_path,
get_workspace_files_by_ids,
list_workspace_files,
soft_delete_workspace_file,
)
@@ -241,7 +237,6 @@ class DatabaseManager(AppService):
# ============ User + Integrations ============ #
get_user_by_id = _(get_user_by_id)
list_users = _(list_users)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
@@ -254,7 +249,6 @@ class DatabaseManager(AppService):
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
check_approval = _(check_approval)
count_pending_reviews_for_graph_exec = _(count_pending_reviews_for_graph_exec)
delete_review_by_node_exec_id = _(delete_review_by_node_exec_id)
get_or_create_human_review = _(get_or_create_human_review)
get_pending_reviews_for_execution = _(get_pending_reviews_for_execution)
@@ -319,16 +313,12 @@ class DatabaseManager(AppService):
# ============ Workspace ============ #
count_workspace_files = _(count_workspace_files)
create_workspace_file = _(create_workspace_file)
get_workspace_files_by_ids = _(get_workspace_files_by_ids)
get_or_create_workspace = _(get_or_create_workspace)
get_workspace_file = _(get_workspace_file)
get_workspace_file_by_path = _(get_workspace_file_by_path)
list_workspace_files = _(list_workspace_files)
soft_delete_workspace_file = _(soft_delete_workspace_file)
# ============ Invited Users ============ #
list_invited_users_for_auth_users = _(list_invited_users_for_auth_users)
# ============ Understanding ============ #
get_business_understanding = _(get_business_understanding)
upsert_business_understanding = _(upsert_business_understanding)
@@ -338,28 +328,8 @@ class DatabaseManager(AppService):
update_block_optimized_description = _(update_block_optimized_description)
# ============ CoPilot Chat Sessions ============ #
get_chat_messages_since = _(chat_db.get_chat_messages_since)
get_chat_session_callback_token = _(chat_db.get_chat_session_callback_token)
get_chat_session = _(chat_db.get_chat_session)
create_chat_session_callback_token = _(chat_db.create_chat_session_callback_token)
create_chat_session = _(chat_db.create_chat_session)
get_manual_chat_sessions_since = _(chat_db.get_manual_chat_sessions_since)
get_pending_notification_chat_sessions = _(
chat_db.get_pending_notification_chat_sessions
)
get_pending_notification_chat_sessions_for_user = _(
chat_db.get_pending_notification_chat_sessions_for_user
)
get_recent_completion_report_chat_sessions = _(
chat_db.get_recent_completion_report_chat_sessions
)
get_recent_sent_email_chat_sessions = _(chat_db.get_recent_sent_email_chat_sessions)
has_recent_manual_message = _(chat_db.has_recent_manual_message)
has_session_since = _(chat_db.has_session_since)
mark_chat_session_callback_token_consumed = _(
chat_db.mark_chat_session_callback_token_consumed
)
session_exists_for_execution_tag = _(chat_db.session_exists_for_execution_tag)
update_chat_session = _(chat_db.update_chat_session)
add_chat_message = _(chat_db.add_chat_message)
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
@@ -404,18 +374,10 @@ class DatabaseManagerClient(AppServiceClient):
get_marketplace_graphs_for_monitoring = _(d.get_marketplace_graphs_for_monitoring)
# Human In The Loop
count_pending_reviews_for_graph_exec = _(d.count_pending_reviews_for_graph_exec)
has_pending_reviews_for_graph_exec = _(d.has_pending_reviews_for_graph_exec)
# User Emails
get_user_email_by_id = _(d.get_user_email_by_id)
list_users = _(d.list_users)
# CoPilot Chat Sessions
get_recent_completion_report_chat_sessions = _(
d.get_recent_completion_report_chat_sessions
)
get_recent_sent_email_chat_sessions = _(d.get_recent_sent_email_chat_sessions)
# Library
list_library_agents = _(d.list_library_agents)
@@ -471,14 +433,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ User + Integrations ============ #
get_user_by_id = d.get_user_by_id
list_users = d.list_users
get_user_integrations = d.get_user_integrations
update_user_integrations = d.update_user_integrations
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
check_approval = d.check_approval
count_pending_reviews_for_graph_exec = d.count_pending_reviews_for_graph_exec
delete_review_by_node_exec_id = d.delete_review_by_node_exec_id
get_or_create_human_review = d.get_or_create_human_review
get_pending_reviews_for_execution = d.get_pending_reviews_for_execution
@@ -546,16 +506,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# ============ Workspace ============ #
count_workspace_files = d.count_workspace_files
create_workspace_file = d.create_workspace_file
get_workspace_files_by_ids = d.get_workspace_files_by_ids
get_or_create_workspace = d.get_or_create_workspace
get_workspace_file = d.get_workspace_file
get_workspace_file_by_path = d.get_workspace_file_by_path
list_workspace_files = d.list_workspace_files
soft_delete_workspace_file = d.soft_delete_workspace_file
# ============ Invited Users ============ #
list_invited_users_for_auth_users = d.list_invited_users_for_auth_users
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding
@@ -564,26 +520,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_blocks_needing_optimization = d.get_blocks_needing_optimization
# ============ CoPilot Chat Sessions ============ #
get_chat_messages_since = d.get_chat_messages_since
get_chat_session_callback_token = d.get_chat_session_callback_token
get_chat_session = d.get_chat_session
create_chat_session_callback_token = d.create_chat_session_callback_token
create_chat_session = d.create_chat_session
get_manual_chat_sessions_since = d.get_manual_chat_sessions_since
get_pending_notification_chat_sessions = d.get_pending_notification_chat_sessions
get_pending_notification_chat_sessions_for_user = (
d.get_pending_notification_chat_sessions_for_user
)
get_recent_completion_report_chat_sessions = (
d.get_recent_completion_report_chat_sessions
)
get_recent_sent_email_chat_sessions = d.get_recent_sent_email_chat_sessions
has_recent_manual_message = d.has_recent_manual_message
has_session_since = d.has_session_since
mark_chat_session_callback_token_consumed = (
d.mark_chat_session_callback_token_consumed
)
session_exists_for_execution_tag = d.session_exists_for_execution_tag
update_chat_session = d.update_chat_session
add_chat_message = d.add_chat_message
add_chat_messages_batch = d.add_chat_messages_batch

View File

@@ -342,19 +342,6 @@ async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
return count > 0
async def count_pending_reviews_for_graph_exec(
graph_exec_id: str,
user_id: str,
) -> int:
return await PendingHumanReview.prisma().count(
where={
"userId": user_id,
"graphExecId": graph_exec_id,
"status": ReviewStatus.WAITING,
}
)
async def _resolve_node_id(node_exec_id: str, get_node_execution) -> str:
"""Resolve node_id from a node_exec_id.

View File

@@ -21,8 +21,8 @@ from backend.data.model import User
from backend.data.redis_client import get_redis_async
from backend.data.tally import get_business_understanding_input_from_tally, mask_email
from backend.data.understanding import (
BusinessUnderstandingInput,
merge_business_understanding_data,
parse_business_understanding_input,
)
from backend.data.user import get_user_by_email, get_user_by_id
from backend.executor.cluster_lock import AsyncClusterLock
@@ -63,16 +63,18 @@ class InvitedUserRecord(BaseModel):
@classmethod
def from_db(cls, invited_user: "prisma.models.InvitedUser") -> "InvitedUserRecord":
payload = parse_business_understanding_input(invited_user.tallyUnderstanding)
payload = (
invited_user.tallyUnderstanding
if isinstance(invited_user.tallyUnderstanding, dict)
else None
)
return cls(
id=invited_user.id,
email=invited_user.email,
status=invited_user.status,
auth_user_id=invited_user.authUserId,
name=invited_user.name,
tally_understanding=(
payload.model_dump(mode="json") if payload is not None else None
),
tally_understanding=payload,
tally_status=invited_user.tallyStatus,
tally_computed_at=invited_user.tallyComputedAt,
tally_error=invited_user.tallyError,
@@ -183,13 +185,19 @@ async def _apply_tally_understanding(
invited_user: "prisma.models.InvitedUser",
tx,
) -> None:
input_data = parse_business_understanding_input(invited_user.tallyUnderstanding)
if input_data is None:
if invited_user.tallyUnderstanding is not None:
logger.warning(
"Malformed tallyUnderstanding for invited user %s; skipping",
invited_user.id,
)
if not isinstance(invited_user.tallyUnderstanding, dict):
return
try:
input_data = BusinessUnderstandingInput.model_validate(
invited_user.tallyUnderstanding
)
except Exception:
logger.warning(
"Malformed tallyUnderstanding for invited user %s; skipping",
invited_user.id,
exc_info=True,
)
return
payload = merge_business_understanding_data({}, input_data)
@@ -215,18 +223,6 @@ async def list_invited_users(
return [InvitedUserRecord.from_db(iu) for iu in invited_users], total
async def list_invited_users_for_auth_users(
auth_user_ids: list[str],
) -> list[InvitedUserRecord]:
if not auth_user_ids:
return []
invited_users = await prisma.models.InvitedUser.prisma().find_many(
where={"authUserId": {"in": auth_user_ids}}
)
return [InvitedUserRecord.from_db(invited_user) for invited_user in invited_users]
async def create_invited_user(
email: str, name: Optional[str] = None
) -> InvitedUserRecord:

View File

@@ -31,18 +31,6 @@ def _json_to_list(value: Any) -> list[str]:
return []
def parse_business_understanding_input(
payload: Any,
) -> "BusinessUnderstandingInput | None":
if payload is None:
return None
try:
return BusinessUnderstandingInput.model_validate(payload)
except pydantic.ValidationError:
return None
class BusinessUnderstandingInput(pydantic.BaseModel):
"""Input model for updating business understanding - all fields optional for incremental updates."""

View File

@@ -62,61 +62,6 @@ async def get_user_by_id(user_id: str) -> User:
return User.from_db(user)
async def list_users(
limit: int = 500,
cursor: str | None = None,
) -> list[User]:
try:
kwargs: dict = {
"take": limit,
"order": {"id": "asc"},
}
if cursor is not None:
kwargs["cursor"] = {"id": cursor}
kwargs["skip"] = 1
users = await PrismaUser.prisma().find_many(**kwargs)
return [User.from_db(user) for user in users]
except Exception as e:
raise DatabaseError(f"Failed to list users: {e}") from e
async def search_users(query: str, limit: int = 20) -> list[User]:
normalized_query = query.strip()
if not normalized_query:
return []
try:
users = await PrismaUser.prisma().find_many(
where={
"OR": [
{
"email": {
"contains": normalized_query,
"mode": "insensitive",
}
},
{
"name": {
"contains": normalized_query,
"mode": "insensitive",
}
},
{
"id": {
"contains": normalized_query,
"mode": "insensitive",
}
},
]
},
order={"updatedAt": "desc"},
take=limit,
)
return [User.from_db(user) for user in users]
except Exception as e:
raise DatabaseError(f"Failed to search users for query {query!r}: {e}") from e
async def get_user_email_by_id(user_id: str) -> Optional[str]:
try:
user = await prisma.user.find_unique(where={"id": user_id})

View File

@@ -254,23 +254,6 @@ async def list_workspace_files(
return [WorkspaceFile.from_db(f) for f in files]
async def get_workspace_files_by_ids(
workspace_id: str,
file_ids: list[str],
) -> list[WorkspaceFile]:
if not file_ids:
return []
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": file_ids},
"workspaceId": workspace_id,
"isDeleted": False,
}
)
return [WorkspaceFile.from_db(file) for file in files]
async def count_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,

View File

@@ -24,12 +24,6 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.copilot.autopilot import (
dispatch_nightly_copilot as dispatch_nightly_copilot_async,
)
from backend.copilot.autopilot import (
send_nightly_copilot_emails as send_nightly_copilot_emails_async,
)
from backend.copilot.optimize_blocks import optimize_block_descriptions
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput, GraphInput
@@ -265,16 +259,6 @@ def cleanup_oauth_tokens():
run_async(_cleanup())
def dispatch_nightly_copilot():
"""Dispatch proactive nightly copilot sessions."""
return run_async(dispatch_nightly_copilot_async())
def send_nightly_copilot_emails():
"""Send emails for completed non-manual copilot sessions."""
return run_async(send_nightly_copilot_emails_async())
def execution_accuracy_alerts():
"""Check execution accuracy and send alerts if drops are detected."""
return report_execution_accuracy_alerts()
@@ -420,7 +404,7 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
) -> "GraphExecutionJobInfo":
# Extract timezone from the trigger if it's a CronTrigger
timezone_str = "UTC"
if isinstance(job_obj.trigger, CronTrigger):
if hasattr(job_obj.trigger, "timezone"):
timezone_str = str(job_obj.trigger.timezone)
return GraphExecutionJobInfo(
@@ -635,24 +619,6 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_job(
dispatch_nightly_copilot,
id="dispatch_nightly_copilot",
trigger=CronTrigger(minute="0,30", timezone=ZoneInfo("UTC")),
replace_existing=True,
max_instances=1,
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_job(
send_nightly_copilot_emails,
id="send_nightly_copilot_emails",
trigger=CronTrigger(minute="15,45", timezone=ZoneInfo("UTC")),
replace_existing=True,
max_instances=1,
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
@@ -827,14 +793,6 @@ class Scheduler(AppService):
"""Manually trigger embedding backfill for approved store agents."""
return ensure_embeddings_coverage()
@expose
def execute_dispatch_nightly_copilot(self):
return dispatch_nightly_copilot()
@expose
def execute_send_nightly_copilot_emails(self):
return send_nightly_copilot_emails()
class SchedulerClient(AppServiceClient):
@classmethod

View File

@@ -1,6 +1,5 @@
import logging
import pathlib
from typing import Any
from postmarker.core import PostmarkClient
from postmarker.models.emails import EmailManager
@@ -8,14 +7,12 @@ from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data.notifications import (
AgentRunData,
NotificationDataType_co,
NotificationEventModel,
NotificationTypeOverride,
)
from backend.util.settings import Settings
from backend.util.text import TextFormatter
from backend.util.url import get_frontend_base_url
logger = logging.getLogger(__name__)
settings = Settings()
@@ -49,102 +46,6 @@ class EmailSender:
MAX_EMAIL_CHARS = 5_000_000 # ~5MB buffer
def _get_unsubscribe_link(self, user_unsubscribe_link: str | None) -> str:
return user_unsubscribe_link or f"{get_frontend_base_url()}/profile/settings"
def _format_template_email(
self,
*,
subject_template: str,
content_template: str,
data: Any,
unsubscribe_link: str,
) -> tuple[str, str]:
return self.formatter.format_email(
base_template=self._read_template("templates/base.html.jinja2"),
subject_template=subject_template,
content_template=content_template,
data=data,
unsubscribe_link=unsubscribe_link,
)
def _build_large_output_summary(
self,
data: (
NotificationEventModel[NotificationDataType_co]
| list[NotificationEventModel[NotificationDataType_co]]
),
*,
email_size: int,
base_url: str,
) -> str:
if isinstance(data, list):
if not data:
return (
"⚠️ A notification generated a very large output "
f"({email_size / 1_000_000:.2f} MB)."
)
event = data[0]
else:
event = data
execution_url = (
f"{base_url}/executions/{event.id}" if event.id is not None else None
)
if isinstance(event.data, AgentRunData):
lines = [
f"⚠️ Your agent '{event.data.agent_name}' generated a very large output ({email_size / 1_000_000:.2f} MB).",
"",
f"Execution time: {event.data.execution_time}",
f"Credits used: {event.data.credits_used}",
]
if execution_url is not None:
lines.append(f"View full results: {execution_url}")
return "\n".join(lines)
lines = [
f"⚠️ A notification generated a very large output ({email_size / 1_000_000:.2f} MB).",
]
if execution_url is not None:
lines.extend(["", f"View full results: {execution_url}"])
return "\n".join(lines)
def send_template(
self,
*,
user_email: str,
subject: str,
template_name: str,
data: dict[str, Any] | None = None,
user_unsubscribe_link: str | None = None,
) -> None:
"""Send an email using a named Jinja2 template file.
Unlike ``send_templated`` (which resolves templates via
``NotificationType``), this method accepts a template filename
directly. Both delegate to the shared ``_format_template_email``
+ ``_send_email`` pipeline.
"""
if not self.postmark:
logger.warning("Postmark client not initialized, email not sent")
return
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
_, full_message = self._format_template_email(
subject_template="{{ subject }}",
content_template=self._read_template(f"templates/{template_name}"),
data={"subject": subject, **(data or {})},
unsubscribe_link=unsubscribe_link,
)
self._send_email(
user_email=user_email,
subject=subject,
body=full_message,
user_unsubscribe_link=unsubscribe_link,
)
def send_templated(
self,
notification: NotificationType,
@@ -161,18 +62,21 @@ class EmailSender:
return
template = self._get_template(notification)
base_url = get_frontend_base_url()
unsubscribe_link = self._get_unsubscribe_link(user_unsub_link)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Normalize data
template_data = {"notifications": data} if isinstance(data, list) else data
try:
subject, full_message = self._format_template_email(
subject, full_message = self.formatter.format_email(
base_template=template.base_template,
subject_template=template.subject_template,
content_template=template.body_template,
data=template_data,
unsubscribe_link=unsubscribe_link,
unsubscribe_link=f"{base_url}/profile/settings",
)
except Exception as e:
logger.error(f"Error formatting full message: {e}")
@@ -186,17 +90,20 @@ class EmailSender:
"Sending summary email instead."
)
summary_message = self._build_large_output_summary(
data,
email_size=email_size,
base_url=base_url,
# Create lightweight summary
summary_message = (
f"⚠️ Your agent '{getattr(data, 'agent_name', 'Unknown')}' "
f"generated a very large output ({email_size / 1_000_000:.2f} MB).\n\n"
f"Execution time: {getattr(data, 'execution_time', 'N/A')}\n"
f"Credits used: {getattr(data, 'credits_used', 'N/A')}\n"
f"View full results: {base_url}/executions/{getattr(data, 'id', 'N/A')}"
)
self._send_email(
user_email=user_email,
subject=f"{subject} (Output Too Large)",
body=summary_message,
user_unsubscribe_link=unsubscribe_link,
user_unsubscribe_link=user_unsub_link,
)
return # Skip sending full email
@@ -205,7 +112,7 @@ class EmailSender:
user_email=user_email,
subject=subject,
body=full_message,
user_unsubscribe_link=unsubscribe_link,
user_unsubscribe_link=user_unsub_link,
)
def _get_template(self, notification: NotificationType):
@@ -216,18 +123,17 @@ class EmailSender:
logger.debug(
f"Template full path: {pathlib.Path(__file__).parent / template_path}"
)
base_template = self._read_template("templates/base.html.jinja2")
template = self._read_template(template_path)
base_template_path = "templates/base.html.jinja2"
with open(pathlib.Path(__file__).parent / base_template_path, "r") as file:
base_template = file.read()
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
template = file.read()
return Template(
subject_template=notification_type_override.subject,
body_template=template,
base_template=base_template,
)
def _read_template(self, template_path: str) -> str:
with open(pathlib.Path(__file__).parent / template_path, "r") as file:
return file.read()
def _send_email(
self,
user_email: str,
@@ -238,33 +144,18 @@ class EmailSender:
if not self.postmark:
logger.warning("Email tried to send without postmark configured")
return
sender_email = settings.config.postmark_sender_email
if not sender_email:
logger.warning("postmark_sender_email not configured, email not sent")
return
unsubscribe_link = self._get_unsubscribe_link(user_unsubscribe_link)
logger.debug(f"Sending email to {user_email} with subject {subject}")
self.postmark.emails.send(
From=sender_email,
From=settings.config.postmark_sender_email,
To=user_email,
Subject=subject,
HtmlBody=body,
Headers={
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
"List-Unsubscribe": f"<{unsubscribe_link}>",
},
)
def send_html(
self,
user_email: str,
subject: str,
body: str,
user_unsubscribe_link: str | None = None,
) -> None:
self._send_email(
user_email=user_email,
subject=subject,
body=body,
user_unsubscribe_link=user_unsubscribe_link,
Headers=(
{
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
"List-Unsubscribe": f"<{user_unsubscribe_link}>",
}
if user_unsubscribe_link
else None
),
)

View File

@@ -1,241 +0,0 @@
from types import SimpleNamespace
from typing import Any, cast
from backend.api.test_helpers import override_config
from backend.copilot.autopilot_email import _markdown_to_email_html
from backend.notifications.email import EmailSender, settings
from backend.util.settings import AppEnvironment
def test_markdown_to_email_html_renders_bold_and_italic() -> None:
html = _markdown_to_email_html("**bold** and *italic*")
assert "<strong>bold</strong>" in html
assert "<em>italic</em>" in html
assert 'style="' in html
def test_markdown_to_email_html_renders_links() -> None:
html = _markdown_to_email_html("[click here](https://example.com)")
assert 'href="https://example.com"' in html
assert "click here" in html
assert "color: #7733F5" in html
def test_markdown_to_email_html_renders_bullet_list() -> None:
html = _markdown_to_email_html("- item one\n- item two")
assert "<ul" in html
assert "<li" in html
assert "item one" in html
assert "item two" in html
def test_markdown_to_email_html_handles_empty_input() -> None:
assert _markdown_to_email_html(None) == ""
assert _markdown_to_email_html("") == ""
assert _markdown_to_email_html(" ") == ""
def test_send_template_renders_nightly_copilot_email(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I found something useful for you.\n\n"
"Open Copilot and I will walk you through it."
),
"cta_url": "https://example.com/copilot?callbackToken=token-1",
"cta_label": "Open Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "I found something useful for you." in body
assert "Open Copilot" in body
assert "Approval needed" not in body
assert send_email.call_args.kwargs["user_unsubscribe_link"].endswith(
"/profile/settings"
)
def test_send_template_renders_nightly_copilot_approval_block(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I prepared a change worth reviewing."
),
"approval_summary_html": _markdown_to_email_html(
"I drafted a follow-up because it matches your recent activity."
),
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
"cta_label": "Review in Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "Approval needed" in body
assert "If you want it to happen, please hit approve." in body
assert "Review in Copilot" in body
def test_send_template_renders_nightly_copilot_callback_email(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot_callback.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I prepared a follow-up based on your recent work."
),
"cta_url": "https://example.com/copilot?callbackToken=token-1",
"cta_label": "Open Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "Autopilot picked up where you left off" in body
assert "I prepared a follow-up based on your recent work." in body
def test_send_template_renders_nightly_copilot_callback_approval_block(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot_callback.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I prepared a follow-up based on your recent work."
),
"approval_summary_html": _markdown_to_email_html(
"I want your approval before I apply the next step."
),
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
"cta_label": "Review in Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "Approval needed" in body
assert "I want your approval before I apply the next step." in body
def test_send_template_renders_nightly_copilot_invite_cta_email(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot_invite_cta.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I put together an example of how Autopilot could help you."
),
"cta_url": "https://example.com/copilot?callbackToken=token-1",
"cta_label": "Try Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "Your Autopilot beta access is waiting" in body
assert "I put together an example of how Autopilot could help you." in body
assert "Try Copilot" in body
def test_send_template_renders_nightly_copilot_invite_cta_approval_block(
mocker,
) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot_invite_cta.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I put together an example of how Autopilot could help you."
),
"approval_summary_html": _markdown_to_email_html(
"If this looks useful, approve the next step to try it."
),
"cta_url": "https://example.com/copilot?sessionId=session-1&showAutopilot=1",
"cta_label": "Review in Copilot",
},
)
body = send_email.call_args.kwargs["body"]
assert "Approval needed" in body
assert "If this looks useful, approve the next step to try it." in body
def test_send_template_still_sends_in_production(mocker) -> None:
sender = EmailSender()
sender.postmark = cast(Any, object())
send_email = mocker.patch.object(sender, "_send_email")
with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
sender.send_template(
user_email="user@example.com",
subject="Autopilot update",
template_name="nightly_copilot.html.jinja2",
data={
"email_body_html": _markdown_to_email_html(
"I found something useful for you."
),
"cta_url": "https://example.com/copilot?callbackToken=token-1",
"cta_label": "Open Copilot",
},
)
send_email.assert_called_once()
def test_send_html_uses_default_unsubscribe_link(mocker) -> None:
sender = EmailSender()
send = mocker.Mock()
sender.postmark = cast(Any, SimpleNamespace(emails=SimpleNamespace(send=send)))
mocker.patch(
"backend.notifications.email.get_frontend_base_url",
return_value="https://example.com",
)
with override_config(settings, "postmark_sender_email", "test@example.com"):
sender.send_html(
user_email="user@example.com",
subject="Autopilot update",
body="<p>Hello</p>",
)
headers = send.call_args.kwargs["Headers"]
assert headers["List-Unsubscribe-Post"] == "List-Unsubscribe=One-Click"
assert headers["List-Unsubscribe"] == "<https://example.com/profile/settings>"

View File

@@ -29,6 +29,7 @@
</noscript>
<![endif]-->
<style type="text/css">
/* RESET STYLES */
html,
body {
margin: 0 !important;
@@ -84,6 +85,7 @@
word-break: break-word;
}
/* iOS BLUE LINKS */
a[x-apple-data-detectors] {
color: inherit !important;
text-decoration: none !important;
@@ -93,11 +95,13 @@
line-height: inherit !important;
}
/* ANDROID CENTER FIX */
div[style*="margin: 16px 0;"] {
margin: 0 !important;
}
@media all and (max-width: 639px) {
/* MEDIA QUERIES */
@media all and (max-width:639px) {
.wrapper {
width: 100% !important;
}
@@ -109,8 +113,8 @@
}
.row {
padding-left: 24px !important;
padding-right: 24px !important;
padding-left: 20px !important;
padding-right: 20px !important;
}
.col-mobile {
@@ -132,6 +136,11 @@
float: none !important;
}
.mobile-left {
text-align: center !important;
float: left !important;
}
.mobile-hide {
display: none !important;
}
@@ -146,9 +155,9 @@
max-width: 100% !important;
}
.card-inner {
padding-left: 24px !important;
padding-right: 24px !important;
.ml-btn-container {
width: 100% !important;
max-width: 100% !important;
}
}
</style>
@@ -165,139 +174,170 @@
<title>{{data.title}}</title>
</head>
<body style="margin: 0 !important; padding: 0 !important; background-color: #070629;">
<body style="margin: 0 !important; padding: 0 !important; background-color:#070629;">
<div class="document" role="article" aria-roledescription="email" aria-label lang dir="ltr"
style="background-color: #070629; line-height: 100%; font-size: medium; font-size: max(16px, 1rem);">
style="background-color:#070629; line-height: 100%; font-size:medium; font-size:max(16px, 1rem);">
<!-- Main Content -->
<table width="100%" align="center" cellspacing="0" cellpadding="0" border="0">
<tr>
<td align="center" valign="top" style="padding: 48px 16px 40px;">
<!-- ============ CARD ============ -->
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<!-- Gradient Accent Strip -->
<tr>
<td bgcolor="#7733F5"
style="background: linear-gradient(90deg, #7733F5 0%, #60A5FA 35%, #EC4899 65%, #7733F5 100%); height: 6px; border-radius: 16px 16px 0 0; font-size: 0; line-height: 0;">
&nbsp;</td>
</tr>
<!-- Logo -->
<tr>
<td bgcolor="#FFFFFF" align="center" class="card-inner" style="padding: 32px 48px 24px;">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="AutoGPT" width="120"
style="max-width: 120px; display: inline-block;">
</td>
</tr>
<!-- Divider -->
<tr>
<td bgcolor="#FFFFFF" class="card-inner" style="padding: 0 48px;">
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<tr>
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;">&nbsp;</td>
</tr>
</table>
</td>
</tr>
<!-- Content -->
<tr>
<td class="card-inner" bgcolor="#FFFFFF"
style="padding: 36px 48px 44px; color: #1F1F20; font-family: 'Poppins', sans-serif; border-radius: 0 0 16px 16px;">
{{data.message|safe}}
</td>
</tr>
</table>
<!-- ============ END CARD ============ -->
<!-- Spacer -->
<table width="640" class="container" align="center" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<tr>
<td style="height: 40px; font-size: 0; line-height: 0;">&nbsp;</td>
</tr>
</table>
<!-- ============ FOOTER ============ -->
<td class="background" bgcolor="#070629" align="center" valign="top" style="padding: 0 8px;">
<!-- Email Content -->
<table class="container" align="center" width="640" cellpadding="0" cellspacing="0" border="0"
style="max-width: 640px;">
<tr>
<td align="center" style="padding: 0 48px;">
<!-- Logo Text -->
<p
style="font-family: 'Poppins', sans-serif; font-size: 22px; font-weight: 700; color: #FFFFFF; margin: 0 0 16px; letter-spacing: -0.5px;">
AutoGPT</p>
<!-- Social Icons -->
<table role="presentation" cellpadding="0" cellspacing="0" border="0" align="center">
<td align="center">
<!-- Logo Section -->
<table class="container ml-4 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td align="center" valign="middle" style="padding: 0 8px;">
<a href="https://x.com/auto_gpt" target="_blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/x.png"
width="22" alt="X" style="display: block; opacity: 0.6;">
</a>
<td class="ml-default-border container" height="40" style="line-height: 40px; min-width: 640px;">
</td>
<td align="center" valign="middle" style="padding: 0 8px;">
<a href="https://discord.gg/autogpt" target="_blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/discord.png"
width="22" alt="Discord" style="display: block; opacity: 0.6;">
</a>
</td>
<td align="center" valign="middle" style="padding: 0 8px;">
<a href="https://agpt.co/" target="_blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/white/website.png"
width="22" alt="Website" style="display: block; opacity: 0.6;">
</a>
</tr>
<tr>
<td>
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="row" align="center" style="padding: 0 50px;">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Spacer -->
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<!-- Main Content Section -->
<table class="container ml-6 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td style="height: 20px; font-size: 0; line-height: 0;">&nbsp;</td>
<td class="row" style="padding: 0 50px;">
{{data.message|safe}}
</td>
</tr>
</table>
<!-- Divider -->
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<!-- Signature Section -->
<table class="container ml-8 ml-default-border" width="640" bgcolor="#E2ECFD" align="center" border="0"
cellspacing="0" cellpadding="0" style="color: #070629; width: 640px; min-width: 640px;">
<tr>
<td
style="border-top: 1px solid rgba(255,255,255,0.08); font-size: 0; line-height: 0; height: 1px;">
&nbsp;</td>
<td class="row mobile-center" align="left" style="padding: 0 50px;">
<table class="ml-8 wrapper" border="0" cellspacing="0" cellpadding="0"
style="color: #070629; text-align: left;">
<tr>
<td class="col center mobile-center" align>
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 16px; line-height: 165%; margin-top: 0; margin-bottom: 0;">
Thank you for being a part of the AutoGPT community! Join the conversation on our Discord <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">here</a> and share your thoughts with us anytime.
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Footer Text -->
<p
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 16px 0 0; text-align: center;">
AutoGPT &middot; 3rd Floor 1 Ashley Road, Altrincham, WA14 2DT, United Kingdom
</p>
<p
style="font-family: 'Poppins', sans-serif; color: rgba(255,255,255,0.3); font-size: 12px; line-height: 165%; margin: 4px 0 0; text-align: center;">
You received this email because you signed up on our website.
<a href="{{data.unsubscribe_link}}"
style="color: rgba(255,255,255,0.45); text-decoration: underline;">Unsubscribe</a>
</p>
<!-- Footer Section -->
<table class="container ml-10 ml-default-border" width="640" bgcolor="#ffffff" align="center" border="0"
cellspacing="0" cellpadding="0" style="width: 640px; min-width: 640px;">
<tr>
<td class="row" style="padding: 0 50px;">
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td>
<!-- Footer Content -->
<table align="center" width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="col" align="left" valign="middle" width="120">
<img
src="https://storage.mlcdn.com/account_image/597379/8QJ8kOjXakVvfe1kJLY2wWCObU1mp5EiDLfBlbQa.png"
border="0" alt="" width="120" class="logo"
style="max-width: 120px; display: inline-block;">
</td>
<td class="col" width="40" height="30" style="line-height: 30px;"></td>
<td class="col mobile-left" align="right" valign="middle" width="250">
<table role="presentation" cellpadding="0" cellspacing="0" border="0">
<tr>
<td align="center" valign="middle" width="18" style="padding: 0 5px 0 0;">
<a href="https://x.com/auto_gpt" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/x.png"
width="18" alt="x">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 5px;">
<a href="https://discord.gg/autogpt" target="blank"
style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/discord.png"
width="18" alt="discord">
</a>
</td>
<td align="center" valign="middle" width="18" style="padding: 0 0 0 5px;">
<a href="https://agpt.co/" target="blank" style="text-decoration: none;">
<img
src="https://assets.mlcdn.com/ml/images/icons/default/rounded_corners/black/website.png"
width="18" alt="website">
</a>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<h5
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 15px; line-height: 125%; font-weight: bold; font-style: normal; text-decoration: none; margin-bottom: 6px;">
AutoGPT
</h5>
</td>
</tr>
<tr>
<td align="center" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
3rd Floor 1 Ashley Road, Cheshire, United Kingdom, WA14 2DT, Altrincham<br>United Kingdom
</p>
</td>
</tr>
<tr>
<td height="8" style="line-height: 8px;"></td>
</tr>
<tr>
<td align="left" style="text-align: left!important;">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
You received this email because you signed up on our website.</p>
</td>
</tr>
<tr>
<td height="1" style="line-height: 12px;"></td>
</tr>
<tr>
<td align="left">
<p
style="font-family: 'Poppins', sans-serif; color: #070629; font-size: 14px; line-height: 150%; display: inline-block; margin-bottom: 0;">
<a href="{{data.unsubscribe_link}}"
style="color: #4285F4; font-weight: normal; font-style: normal; text-decoration: underline;">Unsubscribe</a>
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</td>
</tr>
</table>
</div>
</body>
</html>
</html>

View File

@@ -1,41 +0,0 @@
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
{{ email_body_html|safe }}
{% if approval_summary_html %}
<!-- Approval Callout -->
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
<tr>
<td bgcolor="#FFF3E6"
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<tr>
<td style="padding: 0 0 12px 0;">
<span
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
Approval needed
</span>
</td>
</tr>
</table>
{{ approval_summary_html|safe }}
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
I thought this was a good idea. If you want it to happen, please hit approve.
</p>
</td>
</tr>
</table>
{% endif %}
<!-- CTA Button -->
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
<tr>
<td align="center" bgcolor="#7733F5"
style="background-color: #7733F5; border-radius: 12px;">
<a href="{{ cta_url }}"
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
{{ cta_label }}
</a>
</td>
</tr>
</table>
</div>

View File

@@ -1,58 +0,0 @@
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
<!-- Header -->
<h2
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
Autopilot picked up where you left off
</h2>
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
We used your recent Copilot activity to prepare a concrete follow-up for you.
</p>
<!-- Divider -->
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
<tr>
<td style="border-top: 1px solid #E8E5F0; font-size: 0; line-height: 0; height: 1px;">&nbsp;</td>
</tr>
</table>
<!-- Body -->
{{ email_body_html|safe }}
{% if approval_summary_html %}
<!-- Approval Callout -->
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
<tr>
<td bgcolor="#FFF3E6"
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<tr>
<td style="padding: 0 0 12px 0;">
<span
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
Approval needed
</span>
</td>
</tr>
</table>
{{ approval_summary_html|safe }}
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
I thought this was a good idea. If you want it to happen, please hit approve.
</p>
</td>
</tr>
</table>
{% endif %}
<!-- CTA Button -->
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
<tr>
<td align="center" bgcolor="#7733F5"
style="background-color: #7733F5; border-radius: 12px;">
<a href="{{ cta_url }}"
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
{{ cta_label }}
</a>
</td>
</tr>
</table>
</div>

View File

@@ -1,64 +0,0 @@
<div style="font-family: 'Poppins', sans-serif; color: #1F1F20;">
<!-- Header -->
<h2
style="font-size: 24px; line-height: 130%; font-weight: 700; margin-top: 0; margin-bottom: 8px; color: #1F1F20; letter-spacing: -0.5px;">
Your Autopilot beta access is waiting
</h2>
<p style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 28px; color: #505057;">
You applied to try Autopilot. Here is a tailored example of how it can help once you jump back in.
</p>
<!-- Highlight Card -->
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-bottom: 28px;">
<tr>
<td bgcolor="#5424AE"
style="background-color: #5424AE; border-radius: 12px; padding: 20px 24px;">
<p
style="font-size: 14px; line-height: 165%; margin-top: 0; margin-bottom: 0; color: #FFFFFF; font-weight: 500;">
Autopilot works in the background to handle tasks, surface insights, and take action on your behalf.
</p>
</td>
</tr>
</table>
<!-- Body -->
{{ email_body_html|safe }}
{% if approval_summary_html %}
<!-- Approval Callout -->
<table width="100%" cellpadding="0" cellspacing="0" border="0" style="margin-top: 28px; margin-bottom: 8px;">
<tr>
<td bgcolor="#FFF3E6"
style="background-color: #FFF3E6; border-left: 4px solid #FE8700; border-radius: 12px; padding: 20px 24px;">
<table width="100%" cellpadding="0" cellspacing="0" border="0">
<tr>
<td style="padding: 0 0 12px 0;">
<span
style="display: inline-block; background-color: #FE8700; color: #FFFFFF; font-size: 11px; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; padding: 4px 10px; border-radius: 999px;">
Approval needed
</span>
</td>
</tr>
</table>
{{ approval_summary_html|safe }}
<p style="font-size: 14px; line-height: 165%; margin-top: 8px; margin-bottom: 0; color: #505057;">
I thought this was a good idea. If you want it to happen, please hit approve.
</p>
</td>
</tr>
</table>
{% endif %}
<!-- CTA Button -->
<table cellpadding="0" cellspacing="0" border="0" style="margin-top: 32px;">
<tr>
<td align="center" bgcolor="#7733F5"
style="background-color: #7733F5; border-radius: 12px;">
<a href="{{ cta_url }}"
style="display: inline-block; padding: 16px 36px; background-color: #7733F5; color: #FFFFFF; text-decoration: none; font-family: 'Poppins', sans-serif; font-weight: 600; font-size: 16px; border-radius: 12px; line-height: 1;">
{{ cta_label }}
</a>
</td>
</tr>
</table>
</div>

View File

@@ -39,7 +39,6 @@ class Flag(str, Enum):
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
COPILOT_SDK = "copilot-sdk"
NIGHTLY_COPILOT = "nightly-copilot"
def is_configured() -> bool:

View File

@@ -275,13 +275,12 @@ async def store_media_file(
# Process file
elif file.startswith("data:"):
# Data URI
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
if not match:
parsed_uri = parse_data_uri(file)
if parsed_uri is None:
raise ValueError(
"Invalid data URI format. Expected data:<mime>;base64,<data>"
)
mime_type = match.group(1).strip().lower()
b64_content = match.group(2).strip()
mime_type, b64_content = parsed_uri
# Generate filename and decode
extension = _extension_from_mime(mime_type)
@@ -415,13 +414,70 @@ def get_dir_size(path: Path) -> int:
return total
async def resolve_media_content(
content: MediaFileType,
execution_context: "ExecutionContext",
*,
return_format: MediaReturnFormat,
) -> MediaFileType:
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
Plain text content (source code, filenames) is returned unchanged. Media
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
:func:`store_media_file` using *return_format*.
Use this when a block field is typed as ``MediaFileType`` but may contain
either literal text or a media reference.
"""
if not content or not is_media_file_ref(content):
return content
return await store_media_file(
content, execution_context, return_format=return_format
)
def is_media_file_ref(value: str) -> bool:
"""Return True if *value* looks like a ``MediaFileType`` reference.
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
formats accepted by :func:`store_media_file`. Plain text content
(e.g. source code, filenames) returns False.
Known limitation: HTTP(S) URL detection is heuristic. Any string that
starts with ``http://`` or ``https://`` is treated as a media URL, even
if it appears as a URL inside source-code comments or documentation.
Blocks that produce source code or Markdown as output may therefore
trigger false positives. Callers that need higher precision should
inspect the string further (e.g. verify the URL is reachable or has a
media-friendly extension).
Note: this does *not* match local file paths, which are ambiguous
(could be filenames or actual paths). Blocks that need to resolve
local paths should check for them separately.
"""
return value.startswith(("data:", "workspace://", "http://", "https://"))
def parse_data_uri(value: str) -> tuple[str, str] | None:
"""Parse a ``data:<mime>;base64,<payload>`` URI.
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
or ``None`` if it is not.
"""
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
if not match:
return None
return match.group(1).strip().lower(), match.group(2).strip()
def get_mime_type(file: str) -> str:
"""
Get the MIME type of a file, whether it's a data URI, URL, or local path.
"""
if file.startswith("data:"):
match = re.match(r"^data:([^;]+);base64,", file)
return match.group(1) if match else "application/octet-stream"
parsed_uri = parse_data_uri(file)
return parsed_uri[0] if parsed_uri else "application/octet-stream"
elif file.startswith(("http://", "https://")):
parsed_url = urlparse(file)

View File

@@ -0,0 +1,375 @@
"""Parse file content into structured Python objects based on file format.
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
formats into native Python types *before* schema-driven coercion runs. This
lets blocks with ``Any``-typed inputs receive structured data rather than raw
strings, while blocks expecting strings get the value coerced back via
``convert()``.
Supported formats:
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
when all lines are dicts with the same keys (tabular data), output is
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
otherwise returns a plain ``list`` of parsed values
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
The **fallback contract** is enforced by :func:`parse_file_content`, not by
individual parser functions. If any parser raises, ``parse_file_content``
catches the exception and returns the original content unchanged (string for
text formats, bytes for binary formats). Callers should never see an
exception from the public API when ``strict=False``.
"""
import csv
import io
import json
import logging
import tomllib
import zipfile
from collections.abc import Callable
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
# unlike os.path.splitext which uses platform-native separators.
from posixpath import splitext
from typing import Any
import yaml
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extension / MIME → format label mapping
# ---------------------------------------------------------------------------
_EXT_TO_FORMAT: dict[str, str] = {
".json": "json",
".jsonl": "jsonl",
".ndjson": "jsonl",
".csv": "csv",
".tsv": "tsv",
".yaml": "yaml",
".yml": "yaml",
".toml": "toml",
".parquet": "parquet",
".xlsx": "xlsx",
}
MIME_TO_FORMAT: dict[str, str] = {
"application/json": "json",
"application/x-ndjson": "jsonl",
"application/jsonl": "jsonl",
"text/csv": "csv",
"text/tab-separated-values": "tsv",
"application/x-yaml": "yaml",
"application/yaml": "yaml",
"text/yaml": "yaml",
"application/toml": "toml",
"application/vnd.apache.parquet": "parquet",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
}
# Formats that require raw bytes rather than decoded text.
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
# ---------------------------------------------------------------------------
# Public API (top-down: main functions first, helpers below)
# ---------------------------------------------------------------------------
def infer_format_from_uri(uri: str) -> str | None:
"""Return a format label based on URI extension or MIME fragment.
Returns ``None`` when the format cannot be determined — the caller should
fall back to returning the content as a plain string.
"""
# 1. Check MIME fragment (workspace://abc123#application/json)
if "#" in uri:
_, fragment = uri.rsplit("#", 1)
fmt = MIME_TO_FORMAT.get(fragment.lower())
if fmt:
return fmt
# 2. Check file extension from the path portion.
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
path = uri.split("#")[0].split("?")[0]
_, ext = splitext(path)
fmt = _EXT_TO_FORMAT.get(ext.lower())
if fmt is not None:
return fmt
# Legacy .xls is not supported — map it so callers can produce a
# user-friendly error instead of returning garbled binary.
if ext.lower() == ".xls":
return "xls"
return None
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
"""Parse *content* according to *fmt* and return a native Python value.
When *strict* is ``False`` (default), returns the original *content*
unchanged if *fmt* is not recognised or parsing fails for any reason.
This mode **never raises**.
When *strict* is ``True``, parsing errors are propagated to the caller.
Unrecognised formats or type mismatches (e.g. text for a binary format)
still return *content* unchanged without raising.
"""
if fmt == "xls":
return (
"[Unsupported format] Legacy .xls files are not supported. "
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
)
try:
if fmt in BINARY_FORMATS:
parser = _BINARY_PARSERS.get(fmt)
if parser is None:
return content
if isinstance(content, str):
# Caller gave us text for a binary format — can't parse.
return content
return parser(content)
parser = _TEXT_PARSERS.get(fmt)
if parser is None:
return content
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
return parser(content)
except PARSE_EXCEPTIONS:
if strict:
raise
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
return content
# ---------------------------------------------------------------------------
# Exception loading helpers
# ---------------------------------------------------------------------------
def _load_openpyxl_exception() -> type[Exception]:
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
return InvalidFileException
def _load_arrow_exception() -> type[Exception]:
"""Return pyarrow's ArrowException, raising ImportError if absent."""
from pyarrow import ArrowException # noqa: PLC0415
return ArrowException
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
try:
return loader()
except ImportError:
return None
# Exception types that can be raised during file content parsing.
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
#
# Optional-dependency exception types are loaded via a helper that raises
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
# This ensures mypy sees clean types and missing deps surface as real errors.
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
exc
for exc in (
json.JSONDecodeError,
csv.Error,
yaml.YAMLError,
tomllib.TOMLDecodeError,
ValueError,
UnicodeDecodeError,
ImportError,
OSError,
KeyError,
TypeError,
zipfile.BadZipFile,
_optional_exc(_load_openpyxl_exception),
# ArrowException covers ArrowIOError and ArrowCapacityError which
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
# already map to ValueError/TypeError but this catches the rest.
_optional_exc(_load_arrow_exception),
)
if exc is not None
)
# ---------------------------------------------------------------------------
# Text-based parsers (content: str → Any)
# ---------------------------------------------------------------------------
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
"""Parse *content* and return the result only if it is a container (list/dict).
Scalar values (strings, numbers, booleans, None) are discarded and the
original *content* string is returned instead. This prevents e.g. a JSON
file containing just ``"42"`` from silently becoming an int.
"""
parsed = parser(content)
if isinstance(parsed, (list, dict)):
return parsed
return content
def _parse_json(content: str) -> list | dict | str:
return _parse_container(json.loads, content)
def _parse_jsonl(content: str) -> Any:
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
if not lines:
return content
# When every line is a dict with the same keys, convert to table format
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
keys = list(lines[0].keys())
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
keys_tuple = tuple(keys)
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
return [keys] + [[obj[k] for k in keys] for obj in lines]
return lines
def _parse_csv(content: str) -> Any:
return _parse_delimited(content, delimiter=",")
def _parse_tsv(content: str) -> Any:
return _parse_delimited(content, delimiter="\t")
def _parse_delimited(content: str, *, delimiter: str) -> Any:
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
# csv.reader never yields [] — blank lines yield [""]. Filter out
# rows where every cell is empty (i.e. truly blank lines).
rows = [row for row in reader if _row_has_content(row)]
if not rows:
return content
# If the declared delimiter produces only single-column rows, try
# sniffing the actual delimiter — catches misidentified files (e.g.
# a tab-delimited file with a .csv extension).
if len(rows[0]) == 1:
try:
dialect = csv.Sniffer().sniff(content[:8192])
if dialect.delimiter != delimiter:
reader = csv.reader(io.StringIO(content), dialect)
rows = [row for row in reader if _row_has_content(row)]
except csv.Error:
pass
if rows and len(rows[0]) >= 2:
return rows
return content
def _row_has_content(row: list[str]) -> bool:
"""Return True when *row* contains at least one non-empty cell.
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
This predicate filters those out consistently across the initial read
and the sniffer-fallback re-read.
"""
return any(cell for cell in row)
def _parse_yaml(content: str) -> list | dict | str:
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
# safe_load prevents code execution; for production hardening consider
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
if "\n---" in content or content.startswith("---\n"):
# Multi-document YAML: only the first document is parsed; the rest
# are silently ignored by yaml.safe_load. Warn so callers are aware.
logger.warning(
"Multi-document YAML detected (--- separator); "
"only the first document will be parsed."
)
return _parse_container(yaml.safe_load, content)
def _parse_toml(content: str) -> Any:
parsed = tomllib.loads(content)
# tomllib.loads always returns a dict — return it even if empty.
return parsed
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
"json": _parse_json,
"jsonl": _parse_jsonl,
"csv": _parse_csv,
"tsv": _parse_tsv,
"yaml": _parse_yaml,
"toml": _parse_toml,
}
# ---------------------------------------------------------------------------
# Binary-based parsers (content: bytes → Any)
# ---------------------------------------------------------------------------
def _parse_parquet(content: bytes) -> list[list[Any]]:
import pandas as pd
df = pd.read_parquet(io.BytesIO(content))
return _df_to_rows(df)
def _parse_xlsx(content: bytes) -> list[list[Any]]:
import pandas as pd
# Explicitly specify openpyxl engine; the default engine varies by pandas
# version and does not support legacy .xls (which is excluded by our format map).
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
return _df_to_rows(df)
def _df_to_rows(df: Any) -> list[list[Any]]:
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
NaN values are replaced with ``None`` so the result is JSON-serializable.
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
silently converts ``None`` back to ``NaN`` in float64 columns.
"""
header = df.columns.tolist()
rows = [
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
]
return [header] + rows
def _is_nan(cell: Any) -> bool:
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
``pd.isna()`` on a list/dict returns a boolean array which raises
``ValueError`` in a boolean context. Guard with a scalar check first.
"""
import pandas as pd
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
"parquet": _parse_parquet,
"xlsx": _parse_xlsx,
}

View File

@@ -0,0 +1,624 @@
"""Tests for file_content_parser — format inference and structured parsing."""
import io
import json
import pytest
from backend.util.file_content_parser import (
BINARY_FORMATS,
infer_format_from_uri,
parse_file_content,
)
# ---------------------------------------------------------------------------
# infer_format_from_uri
# ---------------------------------------------------------------------------
class TestInferFormat:
# --- extension-based ---
def test_json_extension(self):
assert infer_format_from_uri("/home/user/data.json") == "json"
def test_jsonl_extension(self):
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
def test_ndjson_extension(self):
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
def test_csv_extension(self):
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
def test_tsv_extension(self):
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
def test_yaml_extension(self):
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
def test_yml_extension(self):
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
def test_toml_extension(self):
assert infer_format_from_uri("/home/user/config.toml") == "toml"
def test_parquet_extension(self):
assert infer_format_from_uri("/data/table.parquet") == "parquet"
def test_xlsx_extension(self):
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
def test_xls_extension_returns_xls_label(self):
# Legacy .xls is mapped so callers can produce a helpful error.
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
def test_case_insensitive(self):
assert infer_format_from_uri("/data/FILE.JSON") == "json"
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
def test_unicode_filename(self):
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
def test_unknown_extension(self):
assert infer_format_from_uri("/home/user/readme.txt") is None
def test_no_extension(self):
assert infer_format_from_uri("workspace://abc123") is None
# --- MIME-based ---
def test_mime_json(self):
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
def test_mime_csv(self):
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
def test_mime_tsv(self):
assert (
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
== "tsv"
)
def test_mime_ndjson(self):
assert (
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
)
def test_mime_yaml(self):
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
def test_mime_xlsx(self):
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
assert infer_format_from_uri(uri) == "xlsx"
def test_mime_parquet(self):
assert (
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
== "parquet"
)
def test_unknown_mime(self):
assert infer_format_from_uri("workspace://abc123#text/plain") is None
def test_unknown_mime_falls_through_to_extension(self):
# Unknown MIME (text/plain) should fall through to extension-based detection.
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
# --- MIME takes precedence over extension ---
def test_mime_overrides_extension(self):
# .txt extension but JSON MIME → json
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
# ---------------------------------------------------------------------------
# parse_file_content — JSON
# ---------------------------------------------------------------------------
class TestParseJson:
def test_array(self):
result = parse_file_content("[1, 2, 3]", "json")
assert result == [1, 2, 3]
def test_object(self):
result = parse_file_content('{"key": "value"}', "json")
assert result == {"key": "value"}
def test_nested(self):
content = json.dumps({"rows": [[1, 2], [3, 4]]})
result = parse_file_content(content, "json")
assert result == {"rows": [[1, 2], [3, 4]]}
def test_scalar_string_stays_as_string(self):
result = parse_file_content('"hello"', "json")
assert result == '"hello"' # original content, not parsed
def test_scalar_number_stays_as_string(self):
result = parse_file_content("42", "json")
assert result == "42"
def test_scalar_boolean_stays_as_string(self):
result = parse_file_content("true", "json")
assert result == "true"
def test_null_stays_as_string(self):
result = parse_file_content("null", "json")
assert result == "null"
def test_invalid_json_fallback(self):
content = "not json at all"
result = parse_file_content(content, "json")
assert result == content
def test_empty_string_fallback(self):
result = parse_file_content("", "json")
assert result == ""
def test_bytes_input_decoded(self):
result = parse_file_content(b"[1, 2, 3]", "json")
assert result == [1, 2, 3]
# ---------------------------------------------------------------------------
# parse_file_content — JSONL
# ---------------------------------------------------------------------------
class TestParseJsonl:
def test_tabular_uniform_dicts_to_table_format(self):
"""JSONL with uniform dict keys → table format (header + rows),
consistent with CSV/TSV/Parquet/Excel output."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
result = parse_file_content(content, "jsonl")
assert result == [
["name", "color"],
["apple", "red"],
["banana", "yellow"],
["cherry", "red"],
]
def test_tabular_single_key_dicts(self):
"""JSONL with single-key uniform dicts → table format."""
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
result = parse_file_content(content, "jsonl")
assert result == [["a"], [1], [2], [3]]
def test_tabular_blank_lines_skipped(self):
content = '{"a": 1}\n\n{"a": 2}\n'
result = parse_file_content(content, "jsonl")
assert result == [["a"], [1], [2]]
def test_heterogeneous_dicts_stay_as_list(self):
"""JSONL with different keys across objects → list of dicts (no table)."""
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
result = parse_file_content(content, "jsonl")
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
def test_partially_overlapping_keys_stay_as_list(self):
"""JSONL dicts with partially overlapping keys → list of dicts."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
result = parse_file_content(content, "jsonl")
assert result == [
{"name": "apple", "color": "red"},
{"name": "banana", "size": "medium"},
]
def test_mixed_types_stay_as_list(self):
"""JSONL with non-dict lines → list of parsed values (no table)."""
content = '1\n"hello"\n[1,2]\n'
result = parse_file_content(content, "jsonl")
assert result == [1, "hello", [1, 2]]
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
"""JSONL mixing dicts and non-dicts → list of parsed values."""
content = '{"a": 1}\n42\n{"b": 2}'
result = parse_file_content(content, "jsonl")
assert result == [{"a": 1}, 42, {"b": 2}]
def test_tabular_preserves_key_order(self):
"""Table header should follow the key order of the first object."""
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
result = parse_file_content(content, "jsonl")
assert result[0] == ["z", "a"] # order from first object
assert result[1] == [1, 2]
assert result[2] == [3, 4]
def test_single_dict_stays_as_list(self):
"""Single-line JSONL with one dict → [dict], NOT a table.
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
content = '{"a": 1, "b": 2}'
result = parse_file_content(content, "jsonl")
assert result == [{"a": 1, "b": 2}]
def test_tabular_with_none_values(self):
"""Uniform keys but some null values → table with None cells."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
result = parse_file_content(content, "jsonl")
assert result == [
["name", "color"],
["apple", "red"],
["banana", None],
]
def test_empty_file_fallback(self):
result = parse_file_content("", "jsonl")
assert result == ""
def test_all_blank_lines_fallback(self):
result = parse_file_content("\n\n\n", "jsonl")
assert result == "\n\n\n"
def test_invalid_line_fallback(self):
content = '{"a": 1}\nnot json\n'
result = parse_file_content(content, "jsonl")
assert result == content # fallback
# ---------------------------------------------------------------------------
# parse_file_content — CSV
# ---------------------------------------------------------------------------
class TestParseCsv:
def test_basic(self):
content = "Name,Score\nAlice,90\nBob,85"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_quoted_fields(self):
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
result = parse_file_content(content, "csv")
assert result[1] == ["Alice", "Loves, commas"]
def test_single_column_fallback(self):
# Only 1 column — not tabular enough.
content = "Name\nAlice\nBob"
result = parse_file_content(content, "csv")
assert result == content
def test_empty_rows_skipped(self):
content = "A,B\n\n1,2\n\n3,4"
result = parse_file_content(content, "csv")
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
def test_empty_file_fallback(self):
result = parse_file_content("", "csv")
assert result == ""
def test_utf8_bom(self):
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
bom = "\ufeff"
content = bom + "Name,Score\nAlice,90\nBob,85"
result = parse_file_content(content, "csv")
# The BOM may be part of the first header cell; ensure rows are still parsed.
assert len(result) == 3
assert result[1] == ["Alice", "90"]
assert result[2] == ["Bob", "85"]
# ---------------------------------------------------------------------------
# parse_file_content — TSV
# ---------------------------------------------------------------------------
class TestParseTsv:
def test_basic(self):
content = "Name\tScore\nAlice\t90\nBob\t85"
result = parse_file_content(content, "tsv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_single_column_fallback(self):
content = "Name\nAlice\nBob"
result = parse_file_content(content, "tsv")
assert result == content
# ---------------------------------------------------------------------------
# parse_file_content — YAML
# ---------------------------------------------------------------------------
class TestParseYaml:
def test_list(self):
content = "- apple\n- banana\n- cherry"
result = parse_file_content(content, "yaml")
assert result == ["apple", "banana", "cherry"]
def test_dict(self):
content = "name: Alice\nage: 30"
result = parse_file_content(content, "yaml")
assert result == {"name": "Alice", "age": 30}
def test_nested(self):
content = "users:\n - name: Alice\n - name: Bob"
result = parse_file_content(content, "yaml")
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
def test_scalar_stays_as_string(self):
result = parse_file_content("hello world", "yaml")
assert result == "hello world"
def test_invalid_yaml_fallback(self):
content = ":\n :\n invalid: - -"
result = parse_file_content(content, "yaml")
# Malformed YAML should fall back to the original string, not raise.
assert result == content
# ---------------------------------------------------------------------------
# parse_file_content — TOML
# ---------------------------------------------------------------------------
class TestParseToml:
def test_basic(self):
content = '[server]\nhost = "localhost"\nport = 8080'
result = parse_file_content(content, "toml")
assert result == {"server": {"host": "localhost", "port": 8080}}
def test_flat(self):
content = 'name = "test"\ncount = 42'
result = parse_file_content(content, "toml")
assert result == {"name": "test", "count": 42}
def test_empty_string_returns_empty_dict(self):
result = parse_file_content("", "toml")
assert result == {}
def test_invalid_toml_fallback(self):
result = parse_file_content("not = [valid toml", "toml")
assert result == "not = [valid toml"
# ---------------------------------------------------------------------------
# parse_file_content — Parquet (binary)
# ---------------------------------------------------------------------------
try:
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
_has_pyarrow = True
except ImportError:
_has_pyarrow = False
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
class TestParseParquet:
@pytest.fixture
def parquet_bytes(self) -> bytes:
import pandas as pd
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
return buf.getvalue()
def test_basic(self, parquet_bytes: bytes):
result = parse_file_content(parquet_bytes, "parquet")
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
def test_string_input_fallback(self):
# Parquet is binary — string input can't be parsed.
result = parse_file_content("not parquet", "parquet")
assert result == "not parquet"
def test_invalid_bytes_fallback(self):
result = parse_file_content(b"not parquet bytes", "parquet")
assert result == b"not parquet bytes"
def test_empty_bytes_fallback(self):
"""Empty binary input should return the empty bytes, not crash."""
result = parse_file_content(b"", "parquet")
assert result == b""
def test_nan_replaced_with_none(self):
"""NaN values in Parquet must become None for JSON serializability."""
import math
import pandas as pd
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
result = parse_file_content(buf.getvalue(), "parquet")
# Row with NaN in float col → None
assert result[2][0] is None # float NaN → None
assert result[2][1] is None # str None → None
# Ensure no NaN leaks
for row in result[1:]:
for cell in row:
if isinstance(cell, float):
assert not math.isnan(cell), f"NaN leaked: {row}"
# ---------------------------------------------------------------------------
# parse_file_content — Excel (binary)
# ---------------------------------------------------------------------------
class TestParseExcel:
@pytest.fixture
def xlsx_bytes(self) -> bytes:
import pandas as pd
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
buf = io.BytesIO()
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
return buf.getvalue()
def test_basic(self, xlsx_bytes: bytes):
result = parse_file_content(xlsx_bytes, "xlsx")
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
def test_string_input_fallback(self):
result = parse_file_content("not xlsx", "xlsx")
assert result == "not xlsx"
def test_invalid_bytes_fallback(self):
result = parse_file_content(b"not xlsx bytes", "xlsx")
assert result == b"not xlsx bytes"
def test_empty_bytes_fallback(self):
"""Empty binary input should return the empty bytes, not crash."""
result = parse_file_content(b"", "xlsx")
assert result == b""
def test_nan_replaced_with_none(self):
"""NaN values in float columns must become None for JSON serializability."""
import math
import pandas as pd
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
buf = io.BytesIO()
df.to_excel(buf, index=False) # type: ignore[arg-type]
result = parse_file_content(buf.getvalue(), "xlsx")
# Row with NaN in float col → None, not float('nan')
assert result[2][0] is None # float NaN → None
assert result[3][1] is None # str None → None
# Ensure no NaN leaks
for row in result[1:]: # skip header
for cell in row:
if isinstance(cell, float):
assert not math.isnan(cell), f"NaN leaked: {row}"
# ---------------------------------------------------------------------------
# parse_file_content — unknown format / fallback
# ---------------------------------------------------------------------------
class TestFallback:
def test_unknown_format_returns_content(self):
result = parse_file_content("hello world", "xml")
assert result == "hello world"
def test_none_format_returns_content(self):
# Shouldn't normally be called with unrecognised format, but must not crash.
result = parse_file_content("hello", "unknown_format")
assert result == "hello"
# ---------------------------------------------------------------------------
# BINARY_FORMATS
# ---------------------------------------------------------------------------
class TestBinaryFormats:
def test_parquet_is_binary(self):
assert "parquet" in BINARY_FORMATS
def test_xlsx_is_binary(self):
assert "xlsx" in BINARY_FORMATS
def test_text_formats_not_binary(self):
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
assert fmt not in BINARY_FORMATS
# ---------------------------------------------------------------------------
# MIME mapping
# ---------------------------------------------------------------------------
class TestMimeMapping:
def test_application_yaml(self):
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
# ---------------------------------------------------------------------------
# CSV sniffer fallback
# ---------------------------------------------------------------------------
class TestCsvSnifferFallback:
def test_tab_delimited_with_csv_format(self):
"""Tab-delimited content parsed as csv should use sniffer fallback."""
content = "Name\tScore\nAlice\t90\nBob\t85"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_sniffer_failure_returns_content(self):
"""When sniffer fails, single-column falls back to raw content."""
content = "Name\nAlice\nBob"
result = parse_file_content(content, "csv")
assert result == content
# ---------------------------------------------------------------------------
# OpenpyxlInvalidFile fallback
# ---------------------------------------------------------------------------
class TestOpenpyxlFallback:
def test_invalid_xlsx_non_strict(self):
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
result = parse_file_content(b"not xlsx bytes", "xlsx")
assert result == b"not xlsx bytes"
# ---------------------------------------------------------------------------
# Header-only CSV
# ---------------------------------------------------------------------------
class TestHeaderOnlyCsv:
def test_header_only_csv_returns_header_row(self):
"""CSV with only a header row (no data rows) should return [[header]]."""
content = "Name,Score"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"]]
def test_header_only_csv_with_trailing_newline(self):
content = "Name,Score\n"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"]]
# ---------------------------------------------------------------------------
# Binary format + line range (line range ignored for binary formats)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
class TestBinaryFormatLineRange:
def test_parquet_ignores_line_range(self):
"""Binary formats should parse the full file regardless of line range.
Line ranges are meaningless for binary formats (parquet/xlsx) — the
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
should return the complete structured data.
"""
import pandas as pd
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
# parse_file_content itself doesn't take a line range — this tests
# that the full content is parsed even though the bytes could have
# been truncated upstream (it's not, by design).
result = parse_file_content(buf.getvalue(), "parquet")
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
# ---------------------------------------------------------------------------
# Legacy .xls UX
# ---------------------------------------------------------------------------
class TestXlsFallback:
def test_xls_returns_helpful_error_string(self):
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
assert isinstance(result, str)
assert ".xlsx" in result
assert "not supported" in result.lower()
def test_xls_with_string_content(self):
result = parse_file_content("some text", "xls")
assert isinstance(result, str)
assert ".xlsx" in result

View File

@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.file import (
is_media_file_ref,
parse_data_uri,
resolve_media_content,
store_media_file,
)
from backend.util.type import MediaFileType
@@ -344,3 +349,162 @@ class TestFileCloudIntegration:
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# ---------------------------------------------------------------------------
# is_media_file_ref
# ---------------------------------------------------------------------------
class TestIsMediaFileRef:
def test_data_uri(self):
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
def test_workspace_uri(self):
assert is_media_file_ref("workspace://abc123") is True
def test_workspace_uri_with_mime(self):
assert is_media_file_ref("workspace://abc123#image/png") is True
def test_http_url(self):
assert is_media_file_ref("http://example.com/image.png") is True
def test_https_url(self):
assert is_media_file_ref("https://example.com/image.png") is True
def test_plain_text(self):
assert is_media_file_ref("print('hello')") is False
def test_local_path(self):
assert is_media_file_ref("/tmp/file.txt") is False
def test_empty_string(self):
assert is_media_file_ref("") is False
def test_filename(self):
assert is_media_file_ref("image.png") is False
# ---------------------------------------------------------------------------
# parse_data_uri
# ---------------------------------------------------------------------------
class TestParseDataUri:
def test_valid_png(self):
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
assert result is not None
mime, payload = result
assert mime == "image/png"
assert payload == "iVBORw0KGg=="
def test_valid_text(self):
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
assert result is not None
assert result[0] == "text/plain"
assert result[1] == "SGVsbG8="
def test_mime_case_normalized(self):
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
assert result is not None
assert result[0] == "image/png"
def test_not_data_uri(self):
assert parse_data_uri("workspace://abc123") is None
def test_plain_text(self):
assert parse_data_uri("hello world") is None
def test_missing_base64(self):
assert parse_data_uri("data:image/png;utf-8,abc") is None
def test_empty_payload(self):
result = parse_data_uri("data:image/png;base64,")
assert result is not None
assert result[1] == ""
# ---------------------------------------------------------------------------
# resolve_media_content
# ---------------------------------------------------------------------------
class TestResolveMediaContent:
@pytest.mark.asyncio
async def test_plain_text_passthrough(self):
"""Plain text content (not a media ref) passes through unchanged."""
ctx = make_test_context()
result = await resolve_media_content(
MediaFileType("print('hello')"),
ctx,
return_format="for_external_api",
)
assert result == "print('hello')"
@pytest.mark.asyncio
async def test_empty_string_passthrough(self):
"""Empty string passes through unchanged."""
ctx = make_test_context()
result = await resolve_media_content(
MediaFileType(""),
ctx,
return_format="for_external_api",
)
assert result == ""
@pytest.mark.asyncio
async def test_media_ref_delegates_to_store(self):
"""Media references are resolved via store_media_file."""
ctx = make_test_context()
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
) as mock_store:
result = await resolve_media_content(
MediaFileType("workspace://img123"),
ctx,
return_format="for_external_api",
)
assert result == "data:image/png;base64,abc"
mock_store.assert_called_once_with(
MediaFileType("workspace://img123"),
ctx,
return_format="for_external_api",
)
@pytest.mark.asyncio
async def test_data_uri_delegates_to_store(self):
"""Data URIs are also resolved via store_media_file."""
ctx = make_test_context()
data_uri = "data:image/png;base64,iVBORw0KGg=="
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType(data_uri)),
) as mock_store:
result = await resolve_media_content(
MediaFileType(data_uri),
ctx,
return_format="for_external_api",
)
assert result == data_uri
mock_store.assert_called_once()
@pytest.mark.asyncio
async def test_https_url_delegates_to_store(self):
"""HTTPS URLs are resolved via store_media_file."""
ctx = make_test_context()
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
) as mock_store:
result = await resolve_media_content(
MediaFileType("https://example.com/image.png"),
ctx,
return_format="for_local_processing",
)
assert result == "data:image/png;base64,abc"
mock_store.assert_called_once_with(
MediaFileType("https://example.com/image.png"),
ctx,
return_format="for_local_processing",
)

View File

@@ -1,7 +1,6 @@
import json
import os
import re
from datetime import date
from enum import Enum
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
@@ -126,22 +125,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=True,
description="If the invite-only signup gate is enforced",
)
nightly_copilot_callback_start_date: date = Field(
default=date(2026, 2, 8),
description="Users with sessions since this date are eligible for the one-off autopilot callback cohort.",
)
nightly_copilot_invite_cta_start_date: date = Field(
default=date(2026, 3, 13),
description="Invite CTA cohort does not run before this date.",
)
nightly_copilot_invite_cta_delay_hours: int = Field(
default=48,
description="Delay after invite creation before the invite CTA can run.",
)
nightly_copilot_callback_token_ttl_hours: int = Field(
default=24 * 14,
description="TTL for nightly copilot callback tokens.",
)
enable_credit: bool = Field(
default=False,
description="If user credit system is enabled or not",

View File

@@ -86,13 +86,9 @@ class TextFormatter:
"i",
"img",
"li",
"ol",
"p",
"span",
"strong",
"table",
"td",
"tr",
"u",
"ul",
]
@@ -102,15 +98,6 @@ class TextFormatter:
"*": ["class", "style"],
"a": ["href"],
"img": ["src"],
"table": [
"align",
"border",
"cellpadding",
"cellspacing",
"role",
"width",
],
"td": ["align", "bgcolor", "colspan", "height", "valign", "width"],
}
def format_string(self, template_str: str, values=None, **kwargs) -> str:

View File

@@ -1,9 +0,0 @@
from backend.util.settings import Settings
settings = Settings()
def get_frontend_base_url() -> str:
return (
settings.config.frontend_base_url or settings.config.platform_base_url
).rstrip("/")

View File

@@ -183,7 +183,8 @@ class WorkspaceManager:
f"{Config().max_file_size_mb}MB limit"
)
# Virus scan content before persisting (defense in depth)
# Scan here — callers must NOT duplicate this scan.
# WorkspaceManager owns virus scanning for all persisted files.
await scan_content_safe(content, filename=filename)
# Determine path with session scoping

View File

@@ -1,67 +0,0 @@
-- CreateEnum
CREATE TYPE "ChatSessionStartType" AS ENUM(
'MANUAL',
'AUTOPILOT_NIGHTLY',
'AUTOPILOT_CALLBACK',
'AUTOPILOT_INVITE_CTA'
);
-- AlterTable
ALTER TABLE "ChatSession"
ADD COLUMN "startType" "ChatSessionStartType" NOT NULL DEFAULT 'MANUAL',
ADD COLUMN "executionTag" TEXT,
ADD COLUMN "sessionConfig" JSONB NOT NULL DEFAULT '{}',
ADD COLUMN "completionReport" JSONB,
ADD COLUMN "completionReportRepairCount" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN "completionReportRepairQueuedAt" TIMESTAMP(3),
ADD COLUMN "completedAt" TIMESTAMP(3),
ADD COLUMN "notificationEmailSentAt" TIMESTAMP(3),
ADD COLUMN "notificationEmailSkippedAt" TIMESTAMP(3);
COMMENT ON COLUMN "ChatSession"."sessionConfig" IS 'Validated by backend.copilot.session_types.ChatSessionConfig';
COMMENT ON COLUMN "ChatSession"."completionReport" IS 'Validated by backend.copilot.session_types.StoredCompletionReport';
-- CreateTable
CREATE TABLE "ChatSessionCallbackToken"(
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"sourceSessionId" TEXT,
"callbackSessionMessage" TEXT NOT NULL,
"expiresAt" TIMESTAMP(3) NOT NULL,
"consumedAt" TIMESTAMP(3),
"consumedSessionId" TEXT,
CONSTRAINT "ChatSessionCallbackToken_pkey" PRIMARY KEY("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "ChatSession_userId_executionTag_key"
ON "ChatSession"("userId",
"executionTag");
-- CreateIndex
CREATE INDEX "ChatSession_userId_startType_updatedAt_idx"
ON "ChatSession"("userId",
"startType",
"updatedAt");
-- CreateIndex
CREATE INDEX "ChatSessionCallbackToken_userId_expiresAt_idx"
ON "ChatSessionCallbackToken"("userId",
"expiresAt");
-- CreateIndex
CREATE INDEX "ChatSessionCallbackToken_consumedSessionId_idx"
ON "ChatSessionCallbackToken"("consumedSessionId");
-- AddForeignKey
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_userId_fkey" FOREIGN KEY("userId") REFERENCES "User"("id")
ON DELETE CASCADE
ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ChatSessionCallbackToken" ADD CONSTRAINT "ChatSessionCallbackToken_sourceSessionId_fkey" FOREIGN KEY("sourceSessionId") REFERENCES "ChatSession"("id")
ON DELETE
CASCADE
ON UPDATE CASCADE;

View File

@@ -1360,6 +1360,18 @@ files = [
dnspython = ">=2.0.0"
idna = ">=2.0.0"
[[package]]
name = "et-xmlfile"
version = "2.0.0"
description = "An implementation of lxml.xmlfile for the standard library"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
]
[[package]]
name = "exa-py"
version = "1.16.1"
@@ -4228,6 +4240,21 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<16)"]
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
[[package]]
name = "openpyxl"
version = "3.1.5"
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
]
[package.dependencies]
et-xmlfile = "*"
[[package]]
name = "opentelemetry-api"
version = "1.39.1"
@@ -5430,6 +5457,66 @@ files = [
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
]
[[package]]
name = "pyarrow"
version = "23.0.1"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
]
[[package]]
name = "pyasn1"
version = "0.6.2"
@@ -8882,4 +8969,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"

View File

@@ -37,7 +37,6 @@ jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
langfuse = "^3.14.1"
markdown-it-py = "^3.0.0"
launchdarkly-server-sdk = "^9.14.1"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"
@@ -93,6 +92,8 @@ gravitas-md2gdocs = "^0.1.0"
posthog = "^7.6.0"
fpdf2 = "^2.8.6"
langsmith = "^0.7.7"
openpyxl = "^3.1.5"
pyarrow = "^23.0.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"

View File

@@ -66,7 +66,6 @@ model User {
PendingHumanReviews PendingHumanReview[]
Workspace UserWorkspace?
ClaimedInvite InvitedUser? @relation("InvitedUserAuthUser")
ChatSessionCallbackTokens ChatSessionCallbackToken[]
// OAuth Provider relations
OAuthApplications OAuthApplication[]
@@ -88,13 +87,6 @@ enum TallyComputationStatus {
FAILED
}
enum ChatSessionStartType {
MANUAL
AUTOPILOT_NIGHTLY
AUTOPILOT_CALLBACK
AUTOPILOT_INVITE_CTA
}
model InvitedUser {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -256,15 +248,6 @@ model ChatSession {
// Session metadata
title String?
credentials Json @default("{}") // Map of provider -> credential metadata
startType ChatSessionStartType @default(MANUAL)
executionTag String?
sessionConfig Json @default("{}") // ChatSessionConfig payload from backend.copilot.session_types.ChatSessionConfig
completionReport Json? // StoredCompletionReport payload from backend.copilot.session_types.StoredCompletionReport
completionReportRepairCount Int @default(0)
completionReportRepairQueuedAt DateTime?
completedAt DateTime?
notificationEmailSentAt DateTime?
notificationEmailSkippedAt DateTime?
// Rate limiting counters (stored as JSON maps)
successfulAgentRuns Json @default("{}") // Map of graph_id -> count
@@ -275,31 +258,8 @@ model ChatSession {
totalCompletionTokens Int @default(0)
Messages ChatMessage[]
CallbackTokens ChatSessionCallbackToken[] @relation("ChatSessionCallbackSource")
@@index([userId, updatedAt])
@@index([userId, startType, updatedAt])
@@unique([userId, executionTag])
}
model ChatSessionCallbackToken {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
sourceSessionId String?
SourceSession ChatSession? @relation("ChatSessionCallbackSource", fields: [sourceSessionId], references: [id], onDelete: Cascade)
callbackSessionMessage String
expiresAt DateTime
consumedAt DateTime?
consumedSessionId String?
@@index([userId, expiresAt])
@@index([consumedSessionId])
}
model ChatMessage {

View File

@@ -1,318 +0,0 @@
"use client";
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Button } from "@/components/atoms/Button/Button";
import { Card } from "@/components/atoms/Card/Card";
import { Input } from "@/components/atoms/Input/Input";
import { CopilotUsersTable } from "../CopilotUsersTable/CopilotUsersTable";
import { useAdminCopilotPage } from "../../useAdminCopilotPage";
function getStartTypeLabel(startType: ChatSessionStartType) {
if (startType === ChatSessionStartType.AUTOPILOT_INVITE_CTA) {
return "CTA";
}
if (startType === ChatSessionStartType.AUTOPILOT_NIGHTLY) {
return "Nightly";
}
if (startType === ChatSessionStartType.AUTOPILOT_CALLBACK) {
return "Callback";
}
return startType;
}
const triggerOptions = [
{
label: "Trigger CTA",
description:
"Runs the beta invite CTA flow even if the user would not normally qualify.",
startType: ChatSessionStartType.AUTOPILOT_INVITE_CTA,
variant: "primary" as const,
},
{
label: "Trigger Nightly",
description:
"Runs the nightly proactive Autopilot flow immediately for the selected user.",
startType: ChatSessionStartType.AUTOPILOT_NIGHTLY,
variant: "outline" as const,
},
{
label: "Trigger Callback",
description:
"Runs the callback re-engagement flow without checking the normal callback cohort.",
startType: ChatSessionStartType.AUTOPILOT_CALLBACK,
variant: "secondary" as const,
},
];
export function AdminCopilotPage() {
const {
search,
selectedUser,
pendingTriggerType,
lastTriggeredSession,
lastEmailSweepResult,
searchedUsers,
searchErrorMessage,
isSearchingUsers,
isRefreshingUsers,
isTriggeringSession,
isSendingPendingEmails,
hasSearch,
setSearch,
handleSelectUser,
handleSendPendingEmails,
handleTriggerSession,
} = useAdminCopilotPage();
return (
<div className="mx-auto flex max-w-7xl flex-col gap-6 p-6">
<div className="flex flex-col gap-2">
<h1 className="text-3xl font-bold text-zinc-900">Copilot</h1>
<p className="max-w-3xl text-sm text-zinc-600">
Manually create CTA, Nightly, or Callback Copilot sessions for a
specific user. These controls bypass the normal eligibility checks so
you can test each flow directly.
</p>
</div>
<div className="grid gap-6 xl:grid-cols-[minmax(0,1.35fr),24rem]">
<Card className="border border-zinc-200 shadow-sm">
<div className="flex flex-col gap-4">
<Input
id="copilot-user-search"
label="Search users"
hint="Results update as you type"
placeholder="Search by email, name, or user ID"
value={search}
onChange={(event) => setSearch(event.target.value)}
/>
{searchErrorMessage ? (
<p className="-mt-2 text-sm text-red-500">{searchErrorMessage}</p>
) : null}
<CopilotUsersTable
users={searchedUsers}
isLoading={isSearchingUsers}
isRefreshing={isRefreshingUsers}
hasSearch={hasSearch}
selectedUserId={selectedUser?.id ?? null}
onSelectUser={handleSelectUser}
/>
</div>
</Card>
<div className="flex flex-col gap-6">
<Card className="border border-zinc-200 shadow-sm">
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-3">
<h2 className="text-xl font-semibold text-zinc-900">
Selected user
</h2>
{selectedUser ? <Badge variant="info">Ready</Badge> : null}
</div>
{selectedUser ? (
<div className="flex flex-col gap-3 text-sm text-zinc-600">
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Email
</span>
<p className="mt-1 font-medium text-zinc-900">
{selectedUser.email}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Name
</span>
<p className="mt-1">
{selectedUser.name || "No display name"}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Timezone
</span>
<p className="mt-1">{selectedUser.timezone}</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
User ID
</span>
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
{selectedUser.id}
</p>
</div>
</div>
) : (
<p className="text-sm text-zinc-500">
Select a user from the results table to enable manual Copilot
triggers.
</p>
)}
</div>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">
Trigger flows
</h2>
<p className="text-sm text-zinc-600">
Each action creates a new session immediately for the selected
user.
</p>
</div>
<div className="flex flex-col gap-3">
{triggerOptions.map((option) => (
<div
key={option.startType}
className="rounded-2xl border border-zinc-200 p-4"
>
<div className="flex flex-col gap-3">
<div className="flex flex-col gap-1">
<span className="font-medium text-zinc-900">
{option.label}
</span>
<p className="text-sm text-zinc-600">
{option.description}
</p>
</div>
<Button
variant={option.variant}
disabled={!selectedUser || isTriggeringSession}
loading={pendingTriggerType === option.startType}
onClick={() => handleTriggerSession(option.startType)}
>
{option.label}
</Button>
</div>
</div>
))}
</div>
</div>
</Card>
<Card className="border border-zinc-200 shadow-sm">
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">
Email follow-up
</h2>
<p className="text-sm text-zinc-600">
Run the pending Copilot completion-email sweep immediately for
the selected user.
</p>
</div>
<Button
variant="secondary"
disabled={!selectedUser || isSendingPendingEmails}
loading={isSendingPendingEmails}
onClick={handleSendPendingEmails}
>
Send pending emails
</Button>
{selectedUser && lastEmailSweepResult ? (
<div className="rounded-2xl border border-zinc-200 p-4 text-sm text-zinc-600">
<p className="font-medium text-zinc-900">
Last sweep for {selectedUser.email}
</p>
<div className="mt-3 grid gap-3 sm:grid-cols-2">
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Candidates
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.candidate_count}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Processed
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.processed_count}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Sent
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.sent_count}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Skipped
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.skipped_count}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Repairs queued
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.repair_queued_count}
</p>
</div>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Running / failed
</span>
<p className="mt-1 text-zinc-900">
{lastEmailSweepResult.running_count} /{" "}
{lastEmailSweepResult.failed_count}
</p>
</div>
</div>
</div>
) : null}
</div>
</Card>
{selectedUser && lastTriggeredSession ? (
<Card className="border border-zinc-200 shadow-sm">
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-3">
<h2 className="text-xl font-semibold text-zinc-900">
Latest session
</h2>
<Badge variant="success">
{getStartTypeLabel(lastTriggeredSession.start_type)}
</Badge>
</div>
<p className="text-sm text-zinc-600">
A new Copilot session was created for {selectedUser.email}.
</p>
<div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
Session ID
</span>
<p className="mt-1 break-all font-mono text-xs text-zinc-500">
{lastTriggeredSession.session_id}
</p>
</div>
<Button
as="NextLink"
href={`/copilot?sessionId=${lastTriggeredSession.session_id}&showAutopilot=1`}
target="_blank"
rel="noreferrer"
>
Open session
</Button>
</div>
</Card>
) : null}
</div>
</div>
</div>
);
}

View File

@@ -1,121 +0,0 @@
"use client";
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
import { Button } from "@/components/atoms/Button/Button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/__legacy__/ui/table";
interface Props {
users: AdminCopilotUserSummary[];
isLoading: boolean;
isRefreshing: boolean;
hasSearch: boolean;
selectedUserId: string | null;
onSelectUser: (user: AdminCopilotUserSummary) => void;
}
function formatDate(value: Date) {
return value.toLocaleString();
}
export function CopilotUsersTable({
users,
isLoading,
isRefreshing,
hasSearch,
selectedUserId,
onSelectUser,
}: Props) {
let emptyMessage = "Search by email, name, or user ID to find a user.";
if (hasSearch && isLoading) {
emptyMessage = "Searching users...";
} else if (hasSearch) {
emptyMessage = "No matching users found.";
}
return (
<div className="flex flex-col gap-4">
<div className="flex items-center justify-between gap-4">
<div className="flex flex-col gap-1">
<h2 className="text-xl font-semibold text-zinc-900">User results</h2>
<p className="text-sm text-zinc-600">
Select an existing user, then run an Autopilot flow manually.
</p>
</div>
<span className="text-xs uppercase tracking-[0.18em] text-zinc-400">
{isRefreshing
? "Refreshing"
: `${users.length} result${users.length === 1 ? "" : "s"}`}
</span>
</div>
<div className="overflow-hidden rounded-2xl border border-zinc-200">
<Table>
<TableHeader className="bg-zinc-50">
<TableRow>
<TableHead>User</TableHead>
<TableHead>Timezone</TableHead>
<TableHead>Updated</TableHead>
<TableHead className="text-right">Action</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{users.length === 0 ? (
<TableRow>
<TableCell
colSpan={4}
className="py-10 text-center text-zinc-500"
>
{emptyMessage}
</TableCell>
</TableRow>
) : (
users.map((user) => (
<TableRow key={user.id} className="align-top">
<TableCell>
<div className="flex flex-col gap-1">
<span className="font-medium text-zinc-900">
{user.email}
</span>
<span className="text-sm text-zinc-600">
{user.name || "No display name"}
</span>
<span className="font-mono text-xs text-zinc-400">
{user.id}
</span>
</div>
</TableCell>
<TableCell className="text-sm text-zinc-600">
{user.timezone}
</TableCell>
<TableCell className="text-sm text-zinc-600">
{formatDate(user.updated_at)}
</TableCell>
<TableCell>
<div className="flex justify-end">
<Button
variant={
user.id === selectedUserId ? "secondary" : "outline"
}
size="small"
onClick={() => onSelectUser(user)}
>
{user.id === selectedUserId ? "Selected" : "Select"}
</Button>
</div>
</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
</div>
</div>
);
}

View File

@@ -1,13 +0,0 @@
import { withRoleAccess } from "@/lib/withRoleAccess";
import { AdminCopilotPage } from "./components/AdminCopilotPage/AdminCopilotPage";
function AdminCopilot() {
return <AdminCopilotPage />;
}
export default async function AdminCopilotRoute() {
"use server";
const withAdminAccess = await withRoleAccess(["admin"]);
const ProtectedAdminCopilot = await withAdminAccess(AdminCopilot);
return <ProtectedAdminCopilot />;
}

View File

@@ -1,182 +0,0 @@
"use client";
import type { AdminCopilotUserSummary } from "@/app/api/__generated__/models/adminCopilotUserSummary";
import { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
import type { SendCopilotEmailsResponse } from "@/app/api/__generated__/models/sendCopilotEmailsResponse";
import type { TriggerCopilotSessionResponse } from "@/app/api/__generated__/models/triggerCopilotSessionResponse";
import { okData } from "@/app/api/helpers";
import { customMutator } from "@/app/api/mutators/custom-mutator";
import {
useGetV2SearchCopilotUsers,
usePostV2TriggerCopilotSession,
} from "@/app/api/__generated__/endpoints/admin/admin";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { ApiError } from "@/lib/autogpt-server-api/helpers";
import { useMutation } from "@tanstack/react-query";
import { useDeferredValue, useState } from "react";
function getErrorMessage(error: unknown) {
if (error instanceof ApiError) {
if (
typeof error.response === "object" &&
error.response !== null &&
"detail" in error.response &&
typeof error.response.detail === "string"
) {
return error.response.detail;
}
return error.message;
}
if (error instanceof Error) {
return error.message;
}
return "Something went wrong";
}
export function useAdminCopilotPage() {
const { toast } = useToast();
const [search, setSearch] = useState("");
const [selectedUser, setSelectedUser] =
useState<AdminCopilotUserSummary | null>(null);
const [pendingTriggerType, setPendingTriggerType] =
useState<ChatSessionStartType | null>(null);
const [lastTriggeredSession, setLastTriggeredSession] =
useState<TriggerCopilotSessionResponse | null>(null);
const [lastEmailSweepResult, setLastEmailSweepResult] =
useState<SendCopilotEmailsResponse | null>(null);
const deferredSearch = useDeferredValue(search);
const normalizedSearch = deferredSearch.trim();
const searchUsersQuery = useGetV2SearchCopilotUsers(
normalizedSearch ? { search: normalizedSearch, limit: 20 } : undefined,
{
query: {
enabled: normalizedSearch.length > 0,
select: okData,
},
},
);
const triggerCopilotSessionMutation = usePostV2TriggerCopilotSession({
mutation: {
onSuccess: (response) => {
setPendingTriggerType(null);
const session = okData(response) ?? null;
setLastTriggeredSession(session);
toast({
title: "Copilot session created",
variant: "default",
});
},
onError: (error) => {
setPendingTriggerType(null);
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
},
});
const sendPendingCopilotEmailsMutation = useMutation({
mutationKey: ["sendPendingCopilotEmails"],
mutationFn: async (userId: string) =>
customMutator<{
data: SendCopilotEmailsResponse;
status: number;
headers: Headers;
}>("/api/users/admin/copilot/send-emails", {
method: "POST",
body: JSON.stringify({ user_id: userId }),
}),
onSuccess: (response) => {
const result = okData(response) ?? null;
setLastEmailSweepResult(result);
if (!result) {
toast({
title: "Email sweep completed",
variant: "default",
});
return;
}
toast({
title:
result.sent_count > 0
? `Sent ${result.sent_count} Copilot email${result.sent_count === 1 ? "" : "s"}`
: "Email sweep completed",
description: [
`${result.candidate_count} candidate${result.candidate_count === 1 ? "" : "s"}`,
`${result.sent_count} sent`,
`${result.skipped_count} skipped`,
`${result.repair_queued_count} repairs queued`,
`${result.running_count} still running`,
`${result.failed_count} failed`,
].join(" • "),
variant: "default",
});
},
onError: (error: unknown) => {
toast({
title: getErrorMessage(error),
variant: "destructive",
});
},
});
function handleSelectUser(user: AdminCopilotUserSummary) {
setSelectedUser(user);
setLastTriggeredSession(null);
setLastEmailSweepResult(null);
}
function handleTriggerSession(startType: ChatSessionStartType) {
if (!selectedUser) {
return;
}
setPendingTriggerType(startType);
setLastTriggeredSession(null);
triggerCopilotSessionMutation.mutate({
data: {
user_id: selectedUser.id,
start_type: startType,
},
});
}
function handleSendPendingEmails() {
if (!selectedUser) {
return;
}
setLastEmailSweepResult(null);
sendPendingCopilotEmailsMutation.mutate(selectedUser.id);
}
return {
search,
selectedUser,
pendingTriggerType,
lastTriggeredSession,
lastEmailSweepResult,
searchedUsers: searchUsersQuery.data?.users ?? [],
searchErrorMessage: searchUsersQuery.error
? getErrorMessage(searchUsersQuery.error)
: null,
isSearchingUsers: searchUsersQuery.isLoading,
isRefreshingUsers:
searchUsersQuery.isFetching && !searchUsersQuery.isLoading,
isTriggeringSession: triggerCopilotSessionMutation.isPending,
isSendingPendingEmails: sendPendingCopilotEmailsMutation.isPending,
hasSearch: normalizedSearch.length > 0,
setSearch,
handleSelectUser,
handleTriggerSession,
handleSendPendingEmails,
};
}

View File

@@ -8,7 +8,6 @@ import {
MagnifyingGlassIcon,
FileTextIcon,
SlidersHorizontalIcon,
LightningIcon,
} from "@phosphor-icons/react";
const sidebarLinkGroups = [
@@ -34,11 +33,6 @@ const sidebarLinkGroups = [
href: "/admin/impersonation",
icon: <MagnifyingGlassIcon size={24} />,
},
{
text: "Copilot",
href: "/admin/copilot",
icon: <LightningIcon size={24} />,
},
{
text: "Execution Analytics",
href: "/admin/execution-analytics",

View File

@@ -0,0 +1,440 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { screen, cleanup } from "@testing-library/react";
import { render } from "@/tests/integrations/test-utils";
import React from "react";
import { BlockUIType } from "../components/types";
import type {
CustomNodeData,
CustomNode as CustomNodeType,
} from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { NodeProps } from "@xyflow/react";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
// ---- Mock sub-components ----
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeContainer",
() => ({
NodeContainer: ({
children,
hasErrors,
}: {
children: React.ReactNode;
hasErrors: boolean;
}) => (
<div data-testid="node-container" data-has-errors={String(!!hasErrors)}>
{children}
</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader",
() => ({
NodeHeader: ({ data }: { data: CustomNodeData }) => (
<div data-testid="node-header">{data.title}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/StickyNoteBlock",
() => ({
StickyNoteBlock: ({ data }: { data: CustomNodeData }) => (
<div data-testid="sticky-note-block">{data.title}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeAdvancedToggle",
() => ({
NodeAdvancedToggle: () => <div data-testid="node-advanced-toggle" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput",
() => ({
NodeDataRenderer: () => <div data-testid="node-data-renderer" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeExecutionBadge",
() => ({
NodeExecutionBadge: () => <div data-testid="node-execution-badge" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeRightClickMenu",
() => ({
NodeRightClickMenu: ({ children }: { children: React.ReactNode }) => (
<div data-testid="node-right-click-menu">{children}</div>
),
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/WebhookDisclaimer",
() => ({
WebhookDisclaimer: () => <div data-testid="webhook-disclaimer" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/SubAgentUpdateFeature",
() => ({
SubAgentUpdateFeature: () => <div data-testid="sub-agent-update" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/AyrshareConnectButton",
() => ({
AyrshareConnectButton: () => <div data-testid="ayrshare-connect-button" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/FormCreator",
() => ({
FormCreator: () => <div data-testid="form-creator" />,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/OutputHandler",
() => ({
OutputHandler: () => <div data-testid="output-handler" />,
}),
);
vi.mock(
"@/components/renderers/InputRenderer/utils/input-schema-pre-processor",
() => ({
preprocessInputSchema: (schema: unknown) => schema,
}),
);
vi.mock(
"@/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode",
() => ({
useCustomNode: ({ data }: { data: CustomNodeData }) => ({
inputSchema: data.inputSchema,
outputSchema: data.outputSchema,
isMCPWithTool: false,
}),
}),
);
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: () => ({
getNodes: () => [],
getEdges: () => [],
setNodes: vi.fn(),
setEdges: vi.fn(),
getNode: vi.fn(),
}),
useNodeId: () => "test-node-id",
useUpdateNodeInternals: () => vi.fn(),
Handle: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
Position: { Left: "left", Right: "right", Top: "top", Bottom: "bottom" },
};
});
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
// ---- Helpers ----
function buildNodeData(
overrides: Partial<CustomNodeData> = {},
): CustomNodeData {
return {
hardcodedValues: {},
title: "Test Block",
description: "A test block",
inputSchema: { type: "object", properties: {} },
outputSchema: { type: "object", properties: {} },
uiType: BlockUIType.STANDARD,
block_id: "block-123",
costs: [],
categories: [],
...overrides,
};
}
function buildNodeProps(
dataOverrides: Partial<CustomNodeData> = {},
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
): NodeProps<CustomNodeType> {
return {
id: "node-1",
data: buildNodeData(dataOverrides),
selected: false,
type: "custom",
isConnectable: true,
positionAbsoluteX: 0,
positionAbsoluteY: 0,
zIndex: 0,
dragging: false,
dragHandle: undefined,
draggable: true,
selectable: true,
deletable: true,
parentId: undefined,
width: undefined,
height: undefined,
sourcePosition: undefined,
targetPosition: undefined,
...propsOverrides,
};
}
function renderCustomNode(
dataOverrides: Partial<CustomNodeData> = {},
propsOverrides: Partial<NodeProps<CustomNodeType>> = {},
) {
const props = buildNodeProps(dataOverrides, propsOverrides);
return render(<CustomNode {...props} />);
}
function createExecutionResult(
overrides: Partial<NodeExecutionResult> = {},
): NodeExecutionResult {
return {
node_exec_id: overrides.node_exec_id ?? "exec-1",
node_id: overrides.node_id ?? "node-1",
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
graph_id: overrides.graph_id ?? "graph-1",
graph_version: overrides.graph_version ?? 1,
user_id: overrides.user_id ?? "test-user",
block_id: overrides.block_id ?? "block-1",
status: overrides.status ?? "COMPLETED",
input_data: overrides.input_data ?? {},
output_data: overrides.output_data ?? {},
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
};
}
// ---- Tests ----
beforeEach(() => {
cleanup();
});
describe("CustomNode", () => {
describe("STANDARD type rendering", () => {
it("renders NodeHeader with the block title", () => {
renderCustomNode({ title: "My Standard Block" });
const header = screen.getByTestId("node-header");
expect(header).toBeDefined();
expect(header.textContent).toContain("My Standard Block");
});
it("renders NodeContainer, FormCreator, OutputHandler, and NodeExecutionBadge", () => {
renderCustomNode();
expect(screen.getByTestId("node-container")).toBeDefined();
expect(screen.getByTestId("form-creator")).toBeDefined();
expect(screen.getByTestId("output-handler")).toBeDefined();
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
expect(screen.getByTestId("node-data-renderer")).toBeDefined();
expect(screen.getByTestId("node-advanced-toggle")).toBeDefined();
});
it("wraps content in NodeRightClickMenu", () => {
renderCustomNode();
expect(screen.getByTestId("node-right-click-menu")).toBeDefined();
});
it("does not render StickyNoteBlock for STANDARD type", () => {
renderCustomNode();
expect(screen.queryByTestId("sticky-note-block")).toBeNull();
});
});
describe("NOTE type rendering", () => {
it("renders StickyNoteBlock instead of main UI", () => {
renderCustomNode({ uiType: BlockUIType.NOTE, title: "My Note" });
const note = screen.getByTestId("sticky-note-block");
expect(note).toBeDefined();
expect(note.textContent).toContain("My Note");
});
it("does not render NodeContainer or other standard components", () => {
renderCustomNode({ uiType: BlockUIType.NOTE });
expect(screen.queryByTestId("node-container")).toBeNull();
expect(screen.queryByTestId("node-header")).toBeNull();
expect(screen.queryByTestId("form-creator")).toBeNull();
expect(screen.queryByTestId("output-handler")).toBeNull();
});
});
describe("WEBHOOK type rendering", () => {
it("renders WebhookDisclaimer for WEBHOOK type", () => {
renderCustomNode({ uiType: BlockUIType.WEBHOOK });
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
});
it("renders WebhookDisclaimer for WEBHOOK_MANUAL type", () => {
renderCustomNode({ uiType: BlockUIType.WEBHOOK_MANUAL });
expect(screen.getByTestId("webhook-disclaimer")).toBeDefined();
});
});
describe("AGENT type rendering", () => {
it("renders SubAgentUpdateFeature for AGENT type", () => {
renderCustomNode({ uiType: BlockUIType.AGENT });
expect(screen.getByTestId("sub-agent-update")).toBeDefined();
});
it("does not render SubAgentUpdateFeature for non-AGENT types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.queryByTestId("sub-agent-update")).toBeNull();
});
});
describe("OUTPUT type rendering", () => {
it("does not render OutputHandler for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.queryByTestId("output-handler")).toBeNull();
});
it("still renders FormCreator and other components for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.getByTestId("form-creator")).toBeDefined();
expect(screen.getByTestId("node-header")).toBeDefined();
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
});
describe("AYRSHARE type rendering", () => {
it("renders AyrshareConnectButton for AYRSHARE type", () => {
renderCustomNode({ uiType: BlockUIType.AYRSHARE });
expect(screen.getByTestId("ayrshare-connect-button")).toBeDefined();
});
it("does not render AyrshareConnectButton for non-AYRSHARE types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.queryByTestId("ayrshare-connect-button")).toBeNull();
});
});
describe("error states", () => {
it("sets hasErrors on NodeContainer when data.errors has non-empty values", () => {
renderCustomNode({
errors: { field1: "This field is required" },
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("true");
});
it("does not set hasErrors when data.errors is empty", () => {
renderCustomNode({ errors: {} });
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("does not set hasErrors when data.errors values are all empty strings", () => {
renderCustomNode({ errors: { field1: "" } });
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("sets hasErrors when last execution result has error in output_data", () => {
renderCustomNode({
nodeExecutionResults: [
createExecutionResult({
output_data: { error: ["Something went wrong"] },
}),
],
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("true");
});
it("does not set hasErrors when execution results have no error", () => {
renderCustomNode({
nodeExecutionResults: [
createExecutionResult({
output_data: { result: ["success"] },
}),
],
});
const container = screen.getByTestId("node-container");
expect(container.getAttribute("data-has-errors")).toBe("false");
});
});
describe("NodeExecutionBadge", () => {
it("always renders NodeExecutionBadge for non-NOTE types", () => {
renderCustomNode({ uiType: BlockUIType.STANDARD });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
it("renders NodeExecutionBadge for AGENT type", () => {
renderCustomNode({ uiType: BlockUIType.AGENT });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
it("renders NodeExecutionBadge for OUTPUT type", () => {
renderCustomNode({ uiType: BlockUIType.OUTPUT });
expect(screen.getByTestId("node-execution-badge")).toBeDefined();
});
});
describe("edge cases", () => {
it("renders without nodeExecutionResults", () => {
renderCustomNode({ nodeExecutionResults: undefined });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("renders without errors property", () => {
renderCustomNode({ errors: undefined });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
it("renders with empty execution results array", () => {
renderCustomNode({ nodeExecutionResults: [] });
const container = screen.getByTestId("node-container");
expect(container).toBeDefined();
expect(container.getAttribute("data-has-errors")).toBe("false");
});
});
});

View File

@@ -0,0 +1,342 @@
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { useBlockMenuStore } from "../stores/blockMenuStore";
import { useControlPanelStore } from "../stores/controlPanelStore";
import { DefaultStateType } from "../components/NewControlPanel/NewBlockMenu/types";
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
// ---------------------------------------------------------------------------
// Mocks for heavy child components
// ---------------------------------------------------------------------------
vi.mock(
"../components/NewControlPanel/NewBlockMenu/BlockMenuDefault/BlockMenuDefault",
() => ({
BlockMenuDefault: () => (
<div data-testid="block-menu-default">Default Content</div>
),
}),
);
vi.mock(
"../components/NewControlPanel/NewBlockMenu/BlockMenuSearch/BlockMenuSearch",
() => ({
BlockMenuSearch: () => (
<div data-testid="block-menu-search">Search Results</div>
),
}),
);
// Mock query client used by the search bar hook
vi.mock("@/lib/react-query/queryClient", () => ({
getQueryClient: () => ({
invalidateQueries: vi.fn(),
}),
}));
// ---------------------------------------------------------------------------
// Reset stores before each test
// ---------------------------------------------------------------------------
afterEach(() => {
cleanup();
});
beforeEach(() => {
useBlockMenuStore.getState().reset();
useBlockMenuStore.setState({
filters: [],
creators: [],
creators_list: [],
categoryCounts: {
blocks: 0,
integrations: 0,
marketplace_agents: 0,
my_agents: 0,
},
});
useControlPanelStore.getState().reset();
});
// ===========================================================================
// Section 1: blockMenuStore unit tests
// ===========================================================================
describe("blockMenuStore", () => {
describe("searchQuery", () => {
it("defaults to an empty string", () => {
expect(useBlockMenuStore.getState().searchQuery).toBe("");
});
it("sets the search query", () => {
useBlockMenuStore.getState().setSearchQuery("timer");
expect(useBlockMenuStore.getState().searchQuery).toBe("timer");
});
});
describe("defaultState", () => {
it("defaults to SUGGESTION", () => {
expect(useBlockMenuStore.getState().defaultState).toBe(
DefaultStateType.SUGGESTION,
);
});
it("sets the default state", () => {
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
expect(useBlockMenuStore.getState().defaultState).toBe(
DefaultStateType.ALL_BLOCKS,
);
});
});
describe("filters", () => {
it("defaults to an empty array", () => {
expect(useBlockMenuStore.getState().filters).toEqual([]);
});
it("adds a filter", () => {
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.blocks,
]);
});
it("removes a filter", () => {
useBlockMenuStore
.getState()
.setFilters([
SearchEntryFilterAnyOfItem.blocks,
SearchEntryFilterAnyOfItem.integrations,
]);
useBlockMenuStore
.getState()
.removeFilter(SearchEntryFilterAnyOfItem.blocks);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.integrations,
]);
});
it("replaces all filters with setFilters", () => {
useBlockMenuStore.getState().addFilter(SearchEntryFilterAnyOfItem.blocks);
useBlockMenuStore
.getState()
.setFilters([SearchEntryFilterAnyOfItem.marketplace_agents]);
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.marketplace_agents,
]);
});
});
describe("creators", () => {
it("adds a creator", () => {
useBlockMenuStore.getState().addCreator("alice");
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
});
it("removes a creator", () => {
useBlockMenuStore.getState().setCreators(["alice", "bob"]);
useBlockMenuStore.getState().removeCreator("alice");
expect(useBlockMenuStore.getState().creators).toEqual(["bob"]);
});
it("replaces all creators with setCreators", () => {
useBlockMenuStore.getState().addCreator("alice");
useBlockMenuStore.getState().setCreators(["charlie"]);
expect(useBlockMenuStore.getState().creators).toEqual(["charlie"]);
});
});
describe("categoryCounts", () => {
it("sets category counts", () => {
const counts = {
blocks: 10,
integrations: 5,
marketplace_agents: 3,
my_agents: 2,
};
useBlockMenuStore.getState().setCategoryCounts(counts);
expect(useBlockMenuStore.getState().categoryCounts).toEqual(counts);
});
});
describe("searchId", () => {
it("defaults to undefined", () => {
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
});
it("sets and clears searchId", () => {
useBlockMenuStore.getState().setSearchId("search-123");
expect(useBlockMenuStore.getState().searchId).toBe("search-123");
useBlockMenuStore.getState().setSearchId(undefined);
expect(useBlockMenuStore.getState().searchId).toBeUndefined();
});
});
describe("integration", () => {
it("defaults to undefined", () => {
expect(useBlockMenuStore.getState().integration).toBeUndefined();
});
it("sets the integration", () => {
useBlockMenuStore.getState().setIntegration("slack");
expect(useBlockMenuStore.getState().integration).toBe("slack");
});
});
describe("reset", () => {
it("resets searchQuery, searchId, defaultState, and integration", () => {
useBlockMenuStore.getState().setSearchQuery("hello");
useBlockMenuStore.getState().setSearchId("id-1");
useBlockMenuStore.getState().setDefaultState(DefaultStateType.ALL_BLOCKS);
useBlockMenuStore.getState().setIntegration("github");
useBlockMenuStore.getState().reset();
const state = useBlockMenuStore.getState();
expect(state.searchQuery).toBe("");
expect(state.searchId).toBeUndefined();
expect(state.defaultState).toBe(DefaultStateType.SUGGESTION);
expect(state.integration).toBeUndefined();
});
it("does not reset filters or creators (by design)", () => {
useBlockMenuStore
.getState()
.setFilters([SearchEntryFilterAnyOfItem.blocks]);
useBlockMenuStore.getState().setCreators(["alice"]);
useBlockMenuStore.getState().reset();
expect(useBlockMenuStore.getState().filters).toEqual([
SearchEntryFilterAnyOfItem.blocks,
]);
expect(useBlockMenuStore.getState().creators).toEqual(["alice"]);
});
});
});
// ===========================================================================
// Section 2: controlPanelStore unit tests
// ===========================================================================
describe("controlPanelStore", () => {
it("defaults blockMenuOpen to false", () => {
expect(useControlPanelStore.getState().blockMenuOpen).toBe(false);
});
it("sets blockMenuOpen", () => {
useControlPanelStore.getState().setBlockMenuOpen(true);
expect(useControlPanelStore.getState().blockMenuOpen).toBe(true);
});
it("sets forceOpenBlockMenu", () => {
useControlPanelStore.getState().setForceOpenBlockMenu(true);
expect(useControlPanelStore.getState().forceOpenBlockMenu).toBe(true);
});
it("resets all control panel state", () => {
useControlPanelStore.getState().setBlockMenuOpen(true);
useControlPanelStore.getState().setForceOpenBlockMenu(true);
useControlPanelStore.getState().setSaveControlOpen(true);
useControlPanelStore.getState().setForceOpenSave(true);
useControlPanelStore.getState().reset();
const state = useControlPanelStore.getState();
expect(state.blockMenuOpen).toBe(false);
expect(state.forceOpenBlockMenu).toBe(false);
expect(state.saveControlOpen).toBe(false);
expect(state.forceOpenSave).toBe(false);
});
});
// ===========================================================================
// Section 3: BlockMenuContent integration tests
// ===========================================================================
// We import BlockMenuContent directly to avoid dealing with the Popover wrapper.
import { BlockMenuContent } from "../components/NewControlPanel/NewBlockMenu/BlockMenuContent/BlockMenuContent";
describe("BlockMenuContent", () => {
it("shows BlockMenuDefault when there is no search query", () => {
useBlockMenuStore.getState().setSearchQuery("");
render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
expect(screen.queryByTestId("block-menu-search")).toBeNull();
});
it("shows BlockMenuSearch when a search query is present", () => {
useBlockMenuStore.getState().setSearchQuery("timer");
render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
expect(screen.queryByTestId("block-menu-default")).toBeNull();
});
it("renders the search bar", () => {
render(<BlockMenuContent />);
expect(
screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
),
).toBeDefined();
});
it("switches from default to search view when store query changes", () => {
const { rerender } = render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
// Simulate typing by setting the store directly
useBlockMenuStore.getState().setSearchQuery("webhook");
rerender(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
expect(screen.queryByTestId("block-menu-default")).toBeNull();
});
it("switches back to default view when search query is cleared", () => {
useBlockMenuStore.getState().setSearchQuery("something");
const { rerender } = render(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-search")).toBeDefined();
useBlockMenuStore.getState().setSearchQuery("");
rerender(<BlockMenuContent />);
expect(screen.getByTestId("block-menu-default")).toBeDefined();
expect(screen.queryByTestId("block-menu-search")).toBeNull();
});
it("typing in the search bar updates the local input value", async () => {
render(<BlockMenuContent />);
const input = screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
);
fireEvent.change(input, { target: { value: "slack" } });
expect((input as HTMLInputElement).value).toBe("slack");
});
it("shows clear button when input has text and clears on click", async () => {
render(<BlockMenuContent />);
const input = screen.getByPlaceholderText(
"Blocks, Agents, Integrations or Keywords...",
);
fireEvent.change(input, { target: { value: "test" } });
// The clear button should appear
const clearButton = screen.getByRole("button");
fireEvent.click(clearButton);
await waitFor(() => {
expect((input as HTMLInputElement).value).toBe("");
});
});
});

View File

@@ -0,0 +1,270 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { UseFormReturn, useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import * as z from "zod";
import { renderHook } from "@testing-library/react";
import { useControlPanelStore } from "../stores/controlPanelStore";
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
import { NewSaveControl } from "../components/NewControlPanel/NewSaveControl/NewSaveControl";
import { useNewSaveControl } from "../components/NewControlPanel/NewSaveControl/useNewSaveControl";
const formSchema = z.object({
name: z.string().min(1, "Name is required").max(100),
description: z.string().max(500),
});
type SaveableGraphFormValues = z.infer<typeof formSchema>;
const mockHandleSave = vi.fn();
vi.mock(
"../components/NewControlPanel/NewSaveControl/useNewSaveControl",
() => ({
useNewSaveControl: vi.fn(),
}),
);
const mockUseNewSaveControl = vi.mocked(useNewSaveControl);
function createMockForm(
defaults: SaveableGraphFormValues = { name: "", description: "" },
): UseFormReturn<SaveableGraphFormValues> {
const { result } = renderHook(() =>
useForm<SaveableGraphFormValues>({
resolver: zodResolver(formSchema),
defaultValues: defaults,
}),
);
return result.current;
}
function setupMock(overrides: {
isSaving?: boolean;
graphVersion?: number;
name?: string;
description?: string;
}) {
const form = createMockForm({
name: overrides.name ?? "",
description: overrides.description ?? "",
});
mockUseNewSaveControl.mockReturnValue({
form,
isSaving: overrides.isSaving ?? false,
graphVersion: overrides.graphVersion,
handleSave: mockHandleSave,
});
return form;
}
function resetStore() {
useControlPanelStore.setState({
blockMenuOpen: false,
saveControlOpen: false,
forceOpenBlockMenu: false,
forceOpenSave: false,
});
}
beforeEach(() => {
cleanup();
resetStore();
mockHandleSave.mockReset();
});
afterEach(() => {
cleanup();
});
describe("NewSaveControl", () => {
it("renders save button trigger", () => {
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-save-button")).toBeDefined();
});
it("renders name and description inputs when popover is open", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
expect(screen.getByTestId("save-control-description-input")).toBeDefined();
});
it("does not render popover content when closed", () => {
useControlPanelStore.setState({ saveControlOpen: false });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.queryByTestId("save-control-name-input")).toBeNull();
expect(screen.queryByTestId("save-control-description-input")).toBeNull();
});
it("shows version output when graphVersion is set", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ graphVersion: 3 });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const versionInput = screen.getByTestId("save-control-version-output");
expect(versionInput).toBeDefined();
expect((versionInput as HTMLInputElement).disabled).toBe(true);
});
it("hides version output when graphVersion is undefined", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ graphVersion: undefined });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.queryByTestId("save-control-version-output")).toBeNull();
});
it("enables save button when isSaving is false", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ isSaving: false });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
expect((saveButton as HTMLButtonElement).disabled).toBe(false);
});
it("disables save button when isSaving is true", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ isSaving: true });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByRole("button", { name: /save agent/i });
expect((saveButton as HTMLButtonElement).disabled).toBe(true);
});
it("calls handleSave on form submission with valid data", async () => {
useControlPanelStore.setState({ saveControlOpen: true });
const form = setupMock({ name: "My Agent", description: "A description" });
form.setValue("name", "My Agent");
form.setValue("description", "A description");
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
fireEvent.click(saveButton);
await waitFor(() => {
expect(mockHandleSave).toHaveBeenCalledWith(
{ name: "My Agent", description: "A description" },
expect.anything(),
);
});
});
it("does not call handleSave when name is empty (validation fails)", async () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({ name: "", description: "" });
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const saveButton = screen.getByTestId("save-control-save-agent-button");
fireEvent.click(saveButton);
await waitFor(() => {
expect(mockHandleSave).not.toHaveBeenCalled();
});
});
it("popover stays open when forceOpenSave is true", () => {
useControlPanelStore.setState({
saveControlOpen: false,
forceOpenSave: true,
});
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByTestId("save-control-name-input")).toBeDefined();
});
it("allows typing in name and description inputs", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
const nameInput = screen.getByTestId(
"save-control-name-input",
) as HTMLInputElement;
const descriptionInput = screen.getByTestId(
"save-control-description-input",
) as HTMLInputElement;
fireEvent.change(nameInput, { target: { value: "Test Agent" } });
fireEvent.change(descriptionInput, {
target: { value: "Test Description" },
});
expect(nameInput.value).toBe("Test Agent");
expect(descriptionInput.value).toBe("Test Description");
});
it("displays save button text", () => {
useControlPanelStore.setState({ saveControlOpen: true });
setupMock({});
render(
<TooltipProvider>
<NewSaveControl />
</TooltipProvider>,
);
expect(screen.getByText("Save Agent")).toBeDefined();
});
});

View File

@@ -0,0 +1,147 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { screen, fireEvent, cleanup } from "@testing-library/react";
import { render } from "@/tests/integrations/test-utils";
import React from "react";
import { useGraphStore } from "../stores/graphStore";
vi.mock(
"@/app/(platform)/build/components/BuilderActions/components/RunGraph/useRunGraph",
() => ({
useRunGraph: vi.fn(),
}),
);
vi.mock(
"@/app/(platform)/build/components/BuilderActions/components/RunInputDialog/RunInputDialog",
() => ({
RunInputDialog: ({ isOpen }: { isOpen: boolean }) =>
isOpen ? <div data-testid="run-input-dialog">Dialog</div> : null,
}),
);
// Must import after mocks
import { useRunGraph } from "../components/BuilderActions/components/RunGraph/useRunGraph";
import { RunGraph } from "../components/BuilderActions/components/RunGraph/RunGraph";
const mockUseRunGraph = vi.mocked(useRunGraph);
function createMockReturnValue(
overrides: Partial<ReturnType<typeof useRunGraph>> = {},
) {
return {
handleRunGraph: vi.fn(),
handleStopGraph: vi.fn(),
openRunInputDialog: false,
setOpenRunInputDialog: vi.fn(),
isExecutingGraph: false,
isTerminatingGraph: false,
isSaving: false,
...overrides,
};
}
// RunGraph uses Tooltip which requires TooltipProvider
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
function renderRunGraph(flowID: string | null = "test-flow-id") {
return render(
<TooltipProvider>
<RunGraph flowID={flowID} />
</TooltipProvider>,
);
}
describe("RunGraph", () => {
beforeEach(() => {
cleanup();
mockUseRunGraph.mockReturnValue(createMockReturnValue());
useGraphStore.setState({ isGraphRunning: false });
});
afterEach(() => {
cleanup();
});
it("renders an enabled button when flowID is provided", () => {
renderRunGraph("test-flow-id");
const button = screen.getByRole("button");
expect((button as HTMLButtonElement).disabled).toBe(false);
});
it("renders a disabled button when flowID is null", () => {
renderRunGraph(null);
const button = screen.getByRole("button");
expect((button as HTMLButtonElement).disabled).toBe(true);
});
it("disables the button when isExecutingGraph is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ isExecutingGraph: true }),
);
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("disables the button when isTerminatingGraph is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ isTerminatingGraph: true }),
);
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("disables the button when isSaving is true", () => {
mockUseRunGraph.mockReturnValue(createMockReturnValue({ isSaving: true }));
renderRunGraph();
expect((screen.getByRole("button") as HTMLButtonElement).disabled).toBe(
true,
);
});
it("uses data-id run-graph-button when not running", () => {
renderRunGraph();
const button = screen.getByRole("button");
expect(button.getAttribute("data-id")).toBe("run-graph-button");
});
it("uses data-id stop-graph-button when running", () => {
useGraphStore.setState({ isGraphRunning: true });
renderRunGraph();
const button = screen.getByRole("button");
expect(button.getAttribute("data-id")).toBe("stop-graph-button");
});
it("calls handleRunGraph when clicked and graph is not running", () => {
const handleRunGraph = vi.fn();
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleRunGraph }));
renderRunGraph();
fireEvent.click(screen.getByRole("button"));
expect(handleRunGraph).toHaveBeenCalledOnce();
});
it("calls handleStopGraph when clicked and graph is running", () => {
const handleStopGraph = vi.fn();
mockUseRunGraph.mockReturnValue(createMockReturnValue({ handleStopGraph }));
useGraphStore.setState({ isGraphRunning: true });
renderRunGraph();
fireEvent.click(screen.getByRole("button"));
expect(handleStopGraph).toHaveBeenCalledOnce();
});
it("renders RunInputDialog hidden by default", () => {
renderRunGraph();
expect(screen.queryByTestId("run-input-dialog")).toBeNull();
});
it("renders RunInputDialog when openRunInputDialog is true", () => {
mockUseRunGraph.mockReturnValue(
createMockReturnValue({ openRunInputDialog: true }),
);
renderRunGraph();
expect(screen.getByTestId("run-input-dialog")).toBeDefined();
});
});

View File

@@ -0,0 +1,257 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { BlockUIType } from "../components/types";
vi.mock("@/services/storage/local-storage", () => {
const store: Record<string, string> = {};
return {
Key: { COPIED_FLOW_DATA: "COPIED_FLOW_DATA" },
storage: {
get: (key: string) => store[key] ?? null,
set: (key: string, value: string) => {
store[key] = value;
},
clean: (key: string) => {
delete store[key];
},
},
};
});
import { useCopyPasteStore } from "../stores/copyPasteStore";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { useHistoryStore } from "../stores/historyStore";
import { storage, Key } from "@/services/storage/local-storage";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom",
position: overrides.position ?? { x: 100, y: 200 },
selected: overrides.selected,
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "test node",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: `block-${id}`,
costs: [],
categories: [],
...overrides.data,
},
} as CustomNode;
}
describe("useCopyPasteStore", () => {
beforeEach(() => {
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().clear();
storage.clean(Key.COPIED_FLOW_DATA);
});
describe("copySelectedNodes", () => {
it("copies a single selected node to localStorage", () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
useCopyPasteStore.getState().copySelectedNodes();
const stored = storage.get(Key.COPIED_FLOW_DATA);
expect(stored).not.toBeNull();
const parsed = JSON.parse(stored!);
expect(parsed.nodes).toHaveLength(1);
expect(parsed.nodes[0].id).toBe("1");
expect(parsed.edges).toHaveLength(0);
});
it("copies only edges between selected nodes", () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: true });
const nodeC = createTestNode("c", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
useEdgeStore.setState({
edges: [
{
id: "e-ab",
source: "a",
target: "b",
sourceHandle: "out",
targetHandle: "in",
},
{
id: "e-bc",
source: "b",
target: "c",
sourceHandle: "out",
targetHandle: "in",
},
{
id: "e-ac",
source: "a",
target: "c",
sourceHandle: "out",
targetHandle: "in",
},
],
});
useCopyPasteStore.getState().copySelectedNodes();
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
expect(parsed.nodes).toHaveLength(2);
expect(parsed.edges).toHaveLength(1);
expect(parsed.edges[0].id).toBe("e-ab");
});
it("stores empty data when no nodes are selected", () => {
const node = createTestNode("1", { selected: false });
useNodeStore.setState({ nodes: [node] });
useCopyPasteStore.getState().copySelectedNodes();
const parsed = JSON.parse(storage.get(Key.COPIED_FLOW_DATA)!);
expect(parsed.nodes).toHaveLength(0);
expect(parsed.edges).toHaveLength(0);
});
});
describe("pasteNodes", () => {
it("creates new nodes with new IDs via incrementNodeCounter", () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
const pastedNode = nodes.find((n) => n.id !== "orig");
expect(pastedNode).toBeDefined();
expect(pastedNode!.id).not.toBe("orig");
});
it("offsets pasted node positions by +50 x/y", () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
useNodeStore.setState({ nodes: [node], nodeCounter: 5 });
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
const pastedNode = nodes.find((n) => n.id !== "orig");
expect(pastedNode).toBeDefined();
expect(pastedNode!.position).toEqual({ x: 150, y: 250 });
});
it("preserves internal connections with remapped IDs", () => {
const nodeA = createTestNode("a", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeB = createTestNode("b", {
selected: true,
position: { x: 200, y: 0 },
});
useNodeStore.setState({ nodes: [nodeA, nodeB], nodeCounter: 0 });
useEdgeStore.setState({
edges: [
{
id: "e-ab",
source: "a",
target: "b",
sourceHandle: "output",
targetHandle: "input",
},
],
});
useCopyPasteStore.getState().copySelectedNodes();
useCopyPasteStore.getState().pasteNodes();
const { edges } = useEdgeStore.getState();
const newEdges = edges.filter((e) => e.id !== "e-ab");
expect(newEdges).toHaveLength(1);
const newEdge = newEdges[0];
expect(newEdge.source).not.toBe("a");
expect(newEdge.target).not.toBe("b");
const { nodes } = useNodeStore.getState();
const pastedNodeIDs = nodes
.filter((n) => n.id !== "a" && n.id !== "b")
.map((n) => n.id);
expect(pastedNodeIDs).toContain(newEdge.source);
expect(pastedNodeIDs).toContain(newEdge.target);
});
it("deselects existing nodes and selects pasted ones", () => {
const existingNode = createTestNode("existing", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeToCopy = createTestNode("copy-me", {
selected: true,
position: { x: 100, y: 100 },
});
useNodeStore.setState({
nodes: [existingNode, nodeToCopy],
nodeCounter: 0,
});
useCopyPasteStore.getState().copySelectedNodes();
// Deselect nodeToCopy, keep existingNode selected to verify deselection on paste
useNodeStore.setState({
nodes: [
{ ...existingNode, selected: true },
{ ...nodeToCopy, selected: false },
],
});
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
const originalNodes = nodes.filter(
(n) => n.id === "existing" || n.id === "copy-me",
);
const pastedNodes = nodes.filter(
(n) => n.id !== "existing" && n.id !== "copy-me",
);
originalNodes.forEach((n) => {
expect(n.selected).toBe(false);
});
pastedNodes.forEach((n) => {
expect(n.selected).toBe(true);
});
});
it("does nothing when clipboard is empty", () => {
const node = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [node], nodeCounter: 0 });
useCopyPasteStore.getState().pasteNodes();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
});

View File

@@ -0,0 +1,751 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { MarkerType } from "@xyflow/react";
import { useEdgeStore } from "../stores/edgeStore";
import { useNodeStore } from "../stores/nodeStore";
import { useHistoryStore } from "../stores/historyStore";
import type { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import type { Link } from "@/app/api/__generated__/models/link";
function makeEdge(overrides: Partial<CustomEdge> & { id: string }): CustomEdge {
return {
type: "custom",
source: "node-a",
target: "node-b",
sourceHandle: "output",
targetHandle: "input",
...overrides,
};
}
function makeExecutionResult(
overrides: Partial<NodeExecutionResult>,
): NodeExecutionResult {
return {
user_id: "user-1",
graph_id: "graph-1",
graph_version: 1,
graph_exec_id: "gexec-1",
node_exec_id: "nexec-1",
node_id: "node-1",
block_id: "block-1",
status: "INCOMPLETE",
input_data: {},
output_data: {},
add_time: new Date(),
queue_time: null,
start_time: null,
end_time: null,
...overrides,
};
}
beforeEach(() => {
useEdgeStore.setState({ edges: [] });
useNodeStore.setState({ nodes: [] });
useHistoryStore.setState({ past: [], future: [] });
});
describe("edgeStore", () => {
describe("setEdges", () => {
it("replaces all edges", () => {
const edges = [
makeEdge({ id: "e1" }),
makeEdge({ id: "e2", source: "node-c" }),
];
useEdgeStore.getState().setEdges(edges);
expect(useEdgeStore.getState().edges).toHaveLength(2);
expect(useEdgeStore.getState().edges[0].id).toBe("e1");
expect(useEdgeStore.getState().edges[1].id).toBe("e2");
});
});
describe("addEdge", () => {
it("adds an edge and auto-generates an ID", () => {
const result = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.id).toBe("n1:out->n2:in");
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].id).toBe("n1:out->n2:in");
});
it("uses provided ID when given", () => {
const result = useEdgeStore.getState().addEdge({
id: "custom-id",
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.id).toBe("custom-id");
});
it("sets type to custom and adds arrow marker", () => {
const result = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(result.type).toBe("custom");
expect(result.markerEnd).toEqual({
type: MarkerType.ArrowClosed,
strokeWidth: 2,
color: "#555",
});
});
it("rejects duplicate edges without adding", () => {
useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
const duplicate = useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(duplicate.id).toBe("n1:out->n2:in");
expect(pushSpy).not.toHaveBeenCalled();
pushSpy.mockRestore();
});
it("pushes previous state to history store", () => {
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
useEdgeStore.getState().addEdge({
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
});
expect(pushSpy).toHaveBeenCalledWith({
nodes: [],
edges: [],
});
pushSpy.mockRestore();
});
});
describe("removeEdge", () => {
it("removes an edge by ID", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1" }), makeEdge({ id: "e2" })],
});
useEdgeStore.getState().removeEdge("e1");
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].id).toBe("e2");
});
it("does nothing when removing a non-existent edge", () => {
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
useEdgeStore.getState().removeEdge("nonexistent");
expect(useEdgeStore.getState().edges).toHaveLength(1);
});
it("pushes previous state to history store", () => {
const existingEdges = [makeEdge({ id: "e1" })];
useEdgeStore.setState({ edges: existingEdges });
const pushSpy = vi.spyOn(useHistoryStore.getState(), "pushState");
useEdgeStore.getState().removeEdge("e1");
expect(pushSpy).toHaveBeenCalledWith({
nodes: [],
edges: existingEdges,
});
pushSpy.mockRestore();
});
});
describe("upsertMany", () => {
it("inserts new edges", () => {
useEdgeStore.setState({ edges: [makeEdge({ id: "e1" })] });
useEdgeStore.getState().upsertMany([makeEdge({ id: "e2" })]);
expect(useEdgeStore.getState().edges).toHaveLength(2);
});
it("updates existing edges by ID", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "old-source" })],
});
useEdgeStore
.getState()
.upsertMany([makeEdge({ id: "e1", source: "new-source" })]);
expect(useEdgeStore.getState().edges).toHaveLength(1);
expect(useEdgeStore.getState().edges[0].source).toBe("new-source");
});
it("handles mixed inserts and updates", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "old" })],
});
useEdgeStore
.getState()
.upsertMany([
makeEdge({ id: "e1", source: "updated" }),
makeEdge({ id: "e2", source: "new" }),
]);
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(2);
expect(edges.find((e) => e.id === "e1")?.source).toBe("updated");
expect(edges.find((e) => e.id === "e2")?.source).toBe("new");
});
});
describe("removeEdgesByHandlePrefix", () => {
it("removes edges targeting a node with matching handle prefix", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_foo" }),
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_bar" }),
makeEdge({
id: "e3",
target: "node-b",
targetHandle: "other_handle",
}),
makeEdge({ id: "e4", target: "node-c", targetHandle: "input_foo" }),
],
});
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(2);
expect(edges.map((e) => e.id).sort()).toEqual(["e3", "e4"]);
});
it("does not remove edges where target does not match nodeId", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-b",
target: "node-c",
targetHandle: "input_x",
}),
],
});
useEdgeStore.getState().removeEdgesByHandlePrefix("node-b", "input_");
expect(useEdgeStore.getState().edges).toHaveLength(1);
});
});
describe("getNodeEdges", () => {
it("returns edges where node is source", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-a");
expect(result).toHaveLength(1);
expect(result[0].id).toBe("e1");
});
it("returns edges where node is target", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-c", target: "node-d" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-b");
expect(result).toHaveLength(1);
expect(result[0].id).toBe("e1");
});
it("returns edges for both source and target", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", source: "node-a", target: "node-b" }),
makeEdge({ id: "e2", source: "node-b", target: "node-c" }),
makeEdge({ id: "e3", source: "node-d", target: "node-e" }),
],
});
const result = useEdgeStore.getState().getNodeEdges("node-b");
expect(result).toHaveLength(2);
expect(result.map((e) => e.id).sort()).toEqual(["e1", "e2"]);
});
it("returns empty array for unconnected node", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "node-a", target: "node-b" })],
});
expect(useEdgeStore.getState().getNodeEdges("node-z")).toHaveLength(0);
});
});
describe("isInputConnected", () => {
it("returns true when target handle is connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "input")).toBe(
true,
);
});
it("returns false when target handle is not connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "other")).toBe(
false,
);
});
it("returns false when node is source not target", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-b",
target: "node-c",
sourceHandle: "output",
targetHandle: "input",
}),
],
});
expect(useEdgeStore.getState().isInputConnected("node-b", "output")).toBe(
false,
);
});
});
describe("isOutputConnected", () => {
it("returns true when source handle is connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-a",
sourceHandle: "output",
}),
],
});
expect(
useEdgeStore.getState().isOutputConnected("node-a", "output"),
).toBe(true);
});
it("returns false when source handle is not connected", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "node-a",
sourceHandle: "output",
}),
],
});
expect(useEdgeStore.getState().isOutputConnected("node-a", "other")).toBe(
false,
);
});
});
describe("getBackendLinks", () => {
it("converts edges to Link format", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
source: "n1",
target: "n2",
sourceHandle: "out",
targetHandle: "in",
data: { isStatic: true },
}),
],
});
const links = useEdgeStore.getState().getBackendLinks();
expect(links).toHaveLength(1);
expect(links[0]).toEqual({
id: "e1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
is_static: true,
});
});
});
describe("addLinks", () => {
it("converts Links to edges and adds them", () => {
const links: Link[] = [
{
id: "link-1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
is_static: false,
},
];
useEdgeStore.getState().addLinks(links);
const edges = useEdgeStore.getState().edges;
expect(edges).toHaveLength(1);
expect(edges[0].source).toBe("n1");
expect(edges[0].target).toBe("n2");
expect(edges[0].sourceHandle).toBe("out");
expect(edges[0].targetHandle).toBe("in");
expect(edges[0].data?.isStatic).toBe(false);
});
it("adds multiple links", () => {
const links: Link[] = [
{
id: "link-1",
source_id: "n1",
sink_id: "n2",
source_name: "out",
sink_name: "in",
},
{
id: "link-2",
source_id: "n3",
sink_id: "n4",
source_name: "result",
sink_name: "value",
},
];
useEdgeStore.getState().addLinks(links);
expect(useEdgeStore.getState().edges).toHaveLength(2);
});
});
describe("getAllHandleIdsOfANode", () => {
it("returns targetHandle values for edges targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({ id: "e1", target: "node-b", targetHandle: "input_a" }),
makeEdge({ id: "e2", target: "node-b", targetHandle: "input_b" }),
makeEdge({ id: "e3", target: "node-c", targetHandle: "input_c" }),
],
});
const handles = useEdgeStore.getState().getAllHandleIdsOfANode("node-b");
expect(handles).toEqual(["input_a", "input_b"]);
});
it("returns empty array when no edges target the node", () => {
useEdgeStore.setState({
edges: [makeEdge({ id: "e1", source: "node-b", target: "node-c" })],
});
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual(
[],
);
});
it("returns empty string for edges with no targetHandle", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: undefined,
}),
],
});
expect(useEdgeStore.getState().getAllHandleIdsOfANode("node-b")).toEqual([
"",
]);
});
});
describe("updateEdgeBeads", () => {
it("updates bead counts for edges targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "some-value" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(1);
expect(edge.data?.beadDown).toBe(1);
});
it("counts INCOMPLETE status in beadUp but not beadDown", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "INCOMPLETE",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(1);
expect(edge.data?.beadDown).toBe(0);
});
it("does not modify edges not targeting the node", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-c",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
});
it("does not update edge when input_data has no matching handle", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { other_handle: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
});
it("accumulates beads across multiple executions", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: { beadUp: 0, beadDown: 0, beadData: new Map() },
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data1" },
}),
);
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-2",
status: "INCOMPLETE",
input_data: { input: "data2" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(2);
expect(edge.data?.beadDown).toBe(1);
});
it("handles static edges by setting beadUp to beadDown + 1", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
target: "node-b",
targetHandle: "input",
data: {
isStatic: true,
beadUp: 0,
beadDown: 0,
beadData: new Map(),
},
}),
],
});
useEdgeStore.getState().updateEdgeBeads(
"node-b",
makeExecutionResult({
node_exec_id: "exec-1",
status: "COMPLETED",
input_data: { input: "data" },
}),
);
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.beadUp).toBe(2);
expect(edge.data?.beadDown).toBe(1);
});
});
describe("resetEdgeBeads", () => {
it("resets all bead data on all edges", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
data: {
beadUp: 5,
beadDown: 3,
beadData: new Map([["exec-1", "COMPLETED"]]),
},
}),
makeEdge({
id: "e2",
data: {
beadUp: 2,
beadDown: 1,
beadData: new Map([["exec-2", "INCOMPLETE"]]),
},
}),
],
});
useEdgeStore.getState().resetEdgeBeads();
const edges = useEdgeStore.getState().edges;
for (const edge of edges) {
expect(edge.data?.beadUp).toBe(0);
expect(edge.data?.beadDown).toBe(0);
expect(edge.data?.beadData?.size).toBe(0);
}
});
it("preserves other edge data when resetting beads", () => {
useEdgeStore.setState({
edges: [
makeEdge({
id: "e1",
data: {
isStatic: true,
edgeColorClass: "text-red-500",
beadUp: 3,
beadDown: 2,
beadData: new Map(),
},
}),
],
});
useEdgeStore.getState().resetEdgeBeads();
const edge = useEdgeStore.getState().edges[0];
expect(edge.data?.isStatic).toBe(true);
expect(edge.data?.edgeColorClass).toBe("text-red-500");
expect(edge.data?.beadUp).toBe(0);
});
});
});

View File

@@ -0,0 +1,347 @@
import { describe, it, expect, beforeEach } from "vitest";
import { useGraphStore } from "../stores/graphStore";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
function createTestGraphMeta(
overrides: Partial<GraphMeta> & { id: string; name: string },
): GraphMeta {
return {
version: 1,
description: "",
is_active: true,
user_id: "test-user",
created_at: new Date("2024-01-01T00:00:00Z"),
...overrides,
};
}
function resetStore() {
useGraphStore.setState({
graphExecutionStatus: undefined,
isGraphRunning: false,
inputSchema: null,
credentialsInputSchema: null,
outputSchema: null,
availableSubGraphs: [],
});
}
beforeEach(() => {
resetStore();
});
describe("graphStore", () => {
describe("execution status transitions", () => {
it("handles QUEUED -> RUNNING -> COMPLETED transition", () => {
const { setGraphExecutionStatus } = useGraphStore.getState();
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.QUEUED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.RUNNING,
);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.COMPLETED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("handles QUEUED -> RUNNING -> FAILED transition", () => {
const { setGraphExecutionStatus } = useGraphStore.getState();
setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
setGraphExecutionStatus(AgentExecutionStatus.FAILED);
expect(useGraphStore.getState().graphExecutionStatus).toBe(
AgentExecutionStatus.FAILED,
);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("setGraphExecutionStatus auto-sets isGraphRunning", () => {
it("sets isGraphRunning to true for RUNNING", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
});
it("sets isGraphRunning to true for QUEUED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.QUEUED);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
});
it("sets isGraphRunning to false for COMPLETED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.COMPLETED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for FAILED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.FAILED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for TERMINATED", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.TERMINATED);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for INCOMPLETE", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.INCOMPLETE);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
it("sets isGraphRunning to false for undefined", () => {
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore.getState().setGraphExecutionStatus(undefined);
expect(useGraphStore.getState().graphExecutionStatus).toBeUndefined();
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("setIsGraphRunning", () => {
it("sets isGraphRunning independently of status", () => {
useGraphStore.getState().setIsGraphRunning(true);
expect(useGraphStore.getState().isGraphRunning).toBe(true);
useGraphStore.getState().setIsGraphRunning(false);
expect(useGraphStore.getState().isGraphRunning).toBe(false);
});
});
describe("schema management", () => {
it("sets all three schemas via setGraphSchemas", () => {
const input = { properties: { prompt: { type: "string" } } };
const credentials = { properties: { apiKey: { type: "string" } } };
const output = { properties: { result: { type: "string" } } };
useGraphStore.getState().setGraphSchemas(input, credentials, output);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(input);
expect(state.credentialsInputSchema).toEqual(credentials);
expect(state.outputSchema).toEqual(output);
});
it("sets schemas to null", () => {
const input = { properties: { prompt: { type: "string" } } };
useGraphStore.getState().setGraphSchemas(input, null, null);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(input);
expect(state.credentialsInputSchema).toBeNull();
expect(state.outputSchema).toBeNull();
});
it("overwrites previous schemas", () => {
const first = { properties: { a: { type: "string" } } };
const second = { properties: { b: { type: "number" } } };
useGraphStore.getState().setGraphSchemas(first, first, first);
useGraphStore.getState().setGraphSchemas(second, null, second);
const state = useGraphStore.getState();
expect(state.inputSchema).toEqual(second);
expect(state.credentialsInputSchema).toBeNull();
expect(state.outputSchema).toEqual(second);
});
});
describe("hasInputs", () => {
it("returns false when inputSchema is null", () => {
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns false when inputSchema has no properties", () => {
useGraphStore.getState().setGraphSchemas({}, null, null);
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns false when inputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas({ properties: {} }, null, null);
expect(useGraphStore.getState().hasInputs()).toBe(false);
});
it("returns true when inputSchema has properties", () => {
useGraphStore
.getState()
.setGraphSchemas(
{ properties: { prompt: { type: "string" } } },
null,
null,
);
expect(useGraphStore.getState().hasInputs()).toBe(true);
});
});
describe("hasCredentials", () => {
it("returns false when credentialsInputSchema is null", () => {
expect(useGraphStore.getState().hasCredentials()).toBe(false);
});
it("returns false when credentialsInputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas(null, { properties: {} }, null);
expect(useGraphStore.getState().hasCredentials()).toBe(false);
});
it("returns true when credentialsInputSchema has properties", () => {
useGraphStore
.getState()
.setGraphSchemas(
null,
{ properties: { apiKey: { type: "string" } } },
null,
);
expect(useGraphStore.getState().hasCredentials()).toBe(true);
});
});
describe("hasOutputs", () => {
it("returns false when outputSchema is null", () => {
expect(useGraphStore.getState().hasOutputs()).toBe(false);
});
it("returns false when outputSchema has empty properties", () => {
useGraphStore.getState().setGraphSchemas(null, null, { properties: {} });
expect(useGraphStore.getState().hasOutputs()).toBe(false);
});
it("returns true when outputSchema has properties", () => {
useGraphStore.getState().setGraphSchemas(null, null, {
properties: { result: { type: "string" } },
});
expect(useGraphStore.getState().hasOutputs()).toBe(true);
});
});
describe("reset", () => {
it("clears execution status and schemas but preserves outputSchema and availableSubGraphs", () => {
const subGraphs: GraphMeta[] = [
createTestGraphMeta({
id: "sub-1",
name: "Sub Graph",
description: "A sub graph",
}),
];
useGraphStore
.getState()
.setGraphExecutionStatus(AgentExecutionStatus.RUNNING);
useGraphStore
.getState()
.setGraphSchemas(
{ properties: { a: {} } },
{ properties: { b: {} } },
{ properties: { c: {} } },
);
useGraphStore.getState().setAvailableSubGraphs(subGraphs);
useGraphStore.getState().reset();
const state = useGraphStore.getState();
expect(state.graphExecutionStatus).toBeUndefined();
expect(state.isGraphRunning).toBe(false);
expect(state.inputSchema).toBeNull();
expect(state.credentialsInputSchema).toBeNull();
// reset does not clear outputSchema or availableSubGraphs
expect(state.outputSchema).toEqual({ properties: { c: {} } });
expect(state.availableSubGraphs).toEqual(subGraphs);
});
it("is idempotent on fresh state", () => {
useGraphStore.getState().reset();
const state = useGraphStore.getState();
expect(state.graphExecutionStatus).toBeUndefined();
expect(state.isGraphRunning).toBe(false);
expect(state.inputSchema).toBeNull();
expect(state.credentialsInputSchema).toBeNull();
});
});
describe("setAvailableSubGraphs", () => {
it("sets sub-graphs list", () => {
const graphs: GraphMeta[] = [
createTestGraphMeta({
id: "graph-1",
name: "Graph One",
description: "First graph",
}),
createTestGraphMeta({
id: "graph-2",
version: 2,
name: "Graph Two",
description: "Second graph",
}),
];
useGraphStore.getState().setAvailableSubGraphs(graphs);
expect(useGraphStore.getState().availableSubGraphs).toEqual(graphs);
});
it("replaces previous sub-graphs", () => {
const first: GraphMeta[] = [createTestGraphMeta({ id: "a", name: "A" })];
const second: GraphMeta[] = [
createTestGraphMeta({ id: "b", name: "B" }),
createTestGraphMeta({ id: "c", name: "C" }),
];
useGraphStore.getState().setAvailableSubGraphs(first);
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(1);
useGraphStore.getState().setAvailableSubGraphs(second);
expect(useGraphStore.getState().availableSubGraphs).toHaveLength(2);
expect(useGraphStore.getState().availableSubGraphs).toEqual(second);
});
it("can set empty sub-graphs list", () => {
useGraphStore
.getState()
.setAvailableSubGraphs([createTestGraphMeta({ id: "x", name: "X" })]);
useGraphStore.getState().setAvailableSubGraphs([]);
expect(useGraphStore.getState().availableSubGraphs).toEqual([]);
});
});
});

View File

@@ -0,0 +1,407 @@
import { describe, it, expect, beforeEach } from "vitest";
import { useHistoryStore } from "../stores/historyStore";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom" as const,
position: { x: 0, y: 0 },
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "",
inputSchema: {},
outputSchema: {},
uiType: "STANDARD" as never,
block_id: `block-${id}`,
costs: [],
categories: [],
},
...overrides,
} as CustomNode;
}
function createTestEdge(
id: string,
source: string,
target: string,
): CustomEdge {
return {
id,
source,
target,
type: "custom" as const,
} as CustomEdge;
}
async function flushMicrotasks() {
await new Promise<void>((resolve) => queueMicrotask(resolve));
}
beforeEach(() => {
useHistoryStore.getState().clear();
useNodeStore.setState({ nodes: [] });
useEdgeStore.setState({ edges: [] });
});
describe("historyStore", () => {
describe("undo/redo single action", () => {
it("undoes a single pushed state", async () => {
const node = createTestNode("1");
// Initialize history with node present as baseline
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().initializeHistory();
// Simulate a change: clear nodes
useNodeStore.setState({ nodes: [] });
// Undo should restore to [node]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node]);
expect(useHistoryStore.getState().future).toHaveLength(1);
expect(useHistoryStore.getState().future[0].nodes).toEqual([]);
});
it("redoes after undo", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().initializeHistory();
// Change: clear nodes
useNodeStore.setState({ nodes: [] });
// Undo → back to [node]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node]);
// Redo → back to []
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([]);
});
});
describe("undo/redo multiple actions", () => {
it("undoes through multiple states in order", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
// Initialize with [node1] as baseline
useNodeStore.setState({ nodes: [node1] });
useHistoryStore.getState().initializeHistory();
// Second change: add node2, push pre-change state
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
// Third change: add node3, push pre-change state
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore
.getState()
.pushState({ nodes: [node1, node2], edges: [] });
await flushMicrotasks();
// Undo 1: back to [node1, node2]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
// Undo 2: back to [node1]
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
});
});
describe("undo past empty history", () => {
it("does nothing when there is no history to undo", () => {
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([]);
expect(useEdgeStore.getState().edges).toEqual([]);
expect(useHistoryStore.getState().past).toHaveLength(1);
});
it("does nothing when current state equals last past entry", () => {
expect(useHistoryStore.getState().canUndo()).toBe(false);
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().past).toHaveLength(1);
expect(useHistoryStore.getState().future).toHaveLength(0);
});
});
describe("state consistency: undo after node add restores previous, redo restores added", () => {
it("undo removes added node, redo restores it", async () => {
const node = createTestNode("added");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([]);
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node]);
});
});
describe("history limits", () => {
it("does not grow past MAX_HISTORY (50)", async () => {
for (let i = 0; i < 60; i++) {
const node = createTestNode(`node-${i}`);
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({
nodes: [createTestNode(`node-${i - 1}`)],
edges: [],
});
await flushMicrotasks();
}
expect(useHistoryStore.getState().past.length).toBeLessThanOrEqual(50);
});
});
describe("edge cases", () => {
it("redo does nothing when future is empty", () => {
const nodesBefore = useNodeStore.getState().nodes;
const edgesBefore = useEdgeStore.getState().edges;
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual(nodesBefore);
expect(useEdgeStore.getState().edges).toEqual(edgesBefore);
});
it("interleaved undo/redo sequence", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
useNodeStore.setState({ nodes: [node1] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore.getState().pushState({
nodes: [node1, node2],
edges: [],
});
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2]);
useHistoryStore.getState().undo();
expect(useNodeStore.getState().nodes).toEqual([node1]);
useHistoryStore.getState().redo();
useHistoryStore.getState().redo();
expect(useNodeStore.getState().nodes).toEqual([node1, node2, node3]);
});
});
describe("canUndo / canRedo", () => {
it("canUndo is false on fresh store", () => {
expect(useHistoryStore.getState().canUndo()).toBe(false);
});
it("canUndo is true when current state differs from last past entry", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().canUndo()).toBe(true);
});
it("canRedo is false on fresh store", () => {
expect(useHistoryStore.getState().canRedo()).toBe(false);
});
it("canRedo is true after undo", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().canRedo()).toBe(true);
});
it("canRedo becomes false after redo exhausts future", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
useHistoryStore.getState().redo();
expect(useHistoryStore.getState().canRedo()).toBe(false);
});
});
describe("pushState deduplication", () => {
it("does not push a state identical to the last past entry", async () => {
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().past).toHaveLength(1);
});
it("does not push if state matches current node/edge store state", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().pushState({ nodes: [node], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().past).toHaveLength(1);
});
});
describe("initializeHistory", () => {
it("resets history with current node/edge store state", async () => {
const node = createTestNode("1");
const edge = createTestEdge("e1", "1", "2");
useNodeStore.setState({ nodes: [node] });
useEdgeStore.setState({ edges: [edge] });
useNodeStore.setState({ nodes: [node, createTestNode("2")] });
useHistoryStore.getState().pushState({ nodes: [node], edges: [edge] });
await flushMicrotasks();
useHistoryStore.getState().initializeHistory();
const { past, future } = useHistoryStore.getState();
expect(past).toHaveLength(1);
expect(past[0].nodes).toEqual(useNodeStore.getState().nodes);
expect(past[0].edges).toEqual(useEdgeStore.getState().edges);
expect(future).toHaveLength(0);
});
});
describe("clear", () => {
it("resets to empty initial state", async () => {
const node = createTestNode("1");
useNodeStore.setState({ nodes: [node] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().clear();
const { past, future } = useHistoryStore.getState();
expect(past).toEqual([{ nodes: [], edges: [] }]);
expect(future).toEqual([]);
});
});
describe("microtask batching", () => {
it("only commits the first state when multiple pushState calls happen in the same tick", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
useNodeStore.setState({ nodes: [node1, node2, node3] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
useHistoryStore
.getState()
.pushState({ nodes: [node1, node2], edges: [] });
await flushMicrotasks();
const { past } = useHistoryStore.getState();
expect(past).toHaveLength(2);
expect(past[1].nodes).toEqual([node1]);
});
it("commits separately when pushState calls are in different ticks", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().pushState({ nodes: [node2], edges: [] });
await flushMicrotasks();
const { past } = useHistoryStore.getState();
expect(past).toHaveLength(3);
expect(past[1].nodes).toEqual([node1]);
expect(past[2].nodes).toEqual([node2]);
});
});
describe("edges in undo/redo", () => {
it("restores edges on undo and redo", async () => {
const edge = createTestEdge("e1", "1", "2");
useEdgeStore.setState({ edges: [edge] });
useHistoryStore.getState().pushState({ nodes: [], edges: [] });
await flushMicrotasks();
useHistoryStore.getState().undo();
expect(useEdgeStore.getState().edges).toEqual([]);
useHistoryStore.getState().redo();
expect(useEdgeStore.getState().edges).toEqual([edge]);
});
});
describe("pushState clears future", () => {
it("clears future when a new state is pushed after undo", async () => {
const node1 = createTestNode("1");
const node2 = createTestNode("2");
const node3 = createTestNode("3");
// Initialize empty
useHistoryStore.getState().initializeHistory();
// First change: set [node1]
useNodeStore.setState({ nodes: [node1] });
// Second change: set [node1, node2], push pre-change [node1]
useNodeStore.setState({ nodes: [node1, node2] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
// Undo: back to [node1]
useHistoryStore.getState().undo();
expect(useHistoryStore.getState().future).toHaveLength(1);
// New diverging change: add node3 instead of node2
useNodeStore.setState({ nodes: [node1, node3] });
useHistoryStore.getState().pushState({ nodes: [node1], edges: [] });
await flushMicrotasks();
expect(useHistoryStore.getState().future).toHaveLength(0);
});
});
});

View File

@@ -0,0 +1,791 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { useNodeStore } from "../stores/nodeStore";
import { useHistoryStore } from "../stores/historyStore";
import { useEdgeStore } from "../stores/edgeStore";
import { BlockUIType } from "../components/types";
import type { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { CustomNodeData } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
function createTestNode(overrides: {
id: string;
position?: { x: number; y: number };
data?: Partial<CustomNodeData>;
}): CustomNode {
const defaults: CustomNodeData = {
hardcodedValues: {},
title: "Test Block",
description: "A test block",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: "test-block-id",
costs: [],
categories: [],
};
return {
id: overrides.id,
type: "custom",
position: overrides.position ?? { x: 0, y: 0 },
data: { ...defaults, ...overrides.data },
};
}
function createExecutionResult(
overrides: Partial<NodeExecutionResult> = {},
): NodeExecutionResult {
return {
node_exec_id: overrides.node_exec_id ?? "exec-1",
node_id: overrides.node_id ?? "1",
graph_exec_id: overrides.graph_exec_id ?? "graph-exec-1",
graph_id: overrides.graph_id ?? "graph-1",
graph_version: overrides.graph_version ?? 1,
user_id: overrides.user_id ?? "test-user",
block_id: overrides.block_id ?? "block-1",
status: overrides.status ?? "COMPLETED",
input_data: overrides.input_data ?? { input_key: "input_value" },
output_data: overrides.output_data ?? { output_key: ["output_value"] },
add_time: overrides.add_time ?? new Date("2024-01-01T00:00:00Z"),
queue_time: overrides.queue_time ?? new Date("2024-01-01T00:00:00Z"),
start_time: overrides.start_time ?? new Date("2024-01-01T00:00:01Z"),
end_time: overrides.end_time ?? new Date("2024-01-01T00:00:02Z"),
};
}
function resetStores() {
useNodeStore.setState({
nodes: [],
nodeCounter: 0,
nodeAdvancedStates: {},
latestNodeInputData: {},
latestNodeOutputData: {},
accumulatedNodeInputData: {},
accumulatedNodeOutputData: {},
nodesInResolutionMode: new Set(),
brokenEdgeIDs: new Map(),
nodeResolutionData: new Map(),
});
useEdgeStore.setState({ edges: [] });
useHistoryStore.setState({ past: [], future: [] });
}
describe("nodeStore", () => {
beforeEach(() => {
resetStores();
vi.restoreAllMocks();
});
describe("node lifecycle", () => {
it("starts with empty nodes", () => {
const { nodes } = useNodeStore.getState();
expect(nodes).toEqual([]);
});
it("adds a single node with addNode", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().addNode(node);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
it("sets nodes with setNodes, replacing existing ones", () => {
const node1 = createTestNode({ id: "1" });
const node2 = createTestNode({ id: "2" });
useNodeStore.getState().addNode(node1);
useNodeStore.getState().setNodes([node2]);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("2");
});
it("removes nodes via onNodesChange", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().setNodes([node]);
useNodeStore.getState().onNodesChange([{ type: "remove", id: "1" }]);
expect(useNodeStore.getState().nodes).toHaveLength(0);
});
it("updates node data with updateNodeData", () => {
const node = createTestNode({ id: "1" });
useNodeStore.getState().addNode(node);
useNodeStore.getState().updateNodeData("1", { title: "Updated Title" });
const updated = useNodeStore.getState().nodes[0];
expect(updated.data.title).toBe("Updated Title");
expect(updated.data.block_id).toBe("test-block-id");
});
it("updateNodeData does not affect other nodes", () => {
const node1 = createTestNode({ id: "1" });
const node2 = createTestNode({
id: "2",
data: { title: "Node 2" },
});
useNodeStore.getState().setNodes([node1, node2]);
useNodeStore.getState().updateNodeData("1", { title: "Changed" });
expect(useNodeStore.getState().nodes[1].data.title).toBe("Node 2");
});
});
describe("bulk operations", () => {
it("adds multiple nodes with addNodes", () => {
const nodes = [
createTestNode({ id: "1" }),
createTestNode({ id: "2" }),
createTestNode({ id: "3" }),
];
useNodeStore.getState().addNodes(nodes);
expect(useNodeStore.getState().nodes).toHaveLength(3);
});
it("removes multiple nodes via onNodesChange", () => {
const nodes = [
createTestNode({ id: "1" }),
createTestNode({ id: "2" }),
createTestNode({ id: "3" }),
];
useNodeStore.getState().setNodes(nodes);
useNodeStore.getState().onNodesChange([
{ type: "remove", id: "1" },
{ type: "remove", id: "3" },
]);
const remaining = useNodeStore.getState().nodes;
expect(remaining).toHaveLength(1);
expect(remaining[0].id).toBe("2");
});
});
describe("nodeCounter", () => {
it("starts at zero", () => {
expect(useNodeStore.getState().nodeCounter).toBe(0);
});
it("increments the counter", () => {
useNodeStore.getState().incrementNodeCounter();
expect(useNodeStore.getState().nodeCounter).toBe(1);
useNodeStore.getState().incrementNodeCounter();
expect(useNodeStore.getState().nodeCounter).toBe(2);
});
it("sets the counter to a specific value", () => {
useNodeStore.getState().setNodeCounter(42);
expect(useNodeStore.getState().nodeCounter).toBe(42);
});
});
describe("advanced states", () => {
it("defaults to false for unknown node IDs", () => {
expect(useNodeStore.getState().getShowAdvanced("unknown")).toBe(false);
});
it("toggles advanced state", () => {
useNodeStore.getState().toggleAdvanced("node-1");
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
useNodeStore.getState().toggleAdvanced("node-1");
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
});
it("sets advanced state explicitly", () => {
useNodeStore.getState().setShowAdvanced("node-1", true);
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(true);
useNodeStore.getState().setShowAdvanced("node-1", false);
expect(useNodeStore.getState().getShowAdvanced("node-1")).toBe(false);
});
});
describe("convertCustomNodeToBackendNode", () => {
it("converts a node with minimal data", () => {
const node = createTestNode({
id: "42",
position: { x: 100, y: 200 },
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.id).toBe("42");
expect(backend.block_id).toBe("test-block-id");
expect(backend.input_default).toEqual({});
expect(backend.metadata).toEqual({ position: { x: 100, y: 200 } });
});
it("includes customized_name when present in metadata", () => {
const node = createTestNode({
id: "1",
data: {
metadata: { customized_name: "My Custom Name" },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.metadata).toHaveProperty(
"customized_name",
"My Custom Name",
);
});
it("includes credentials_optional when present in metadata", () => {
const node = createTestNode({
id: "1",
data: {
metadata: { credentials_optional: true },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.metadata).toHaveProperty("credentials_optional", true);
});
it("prunes empty values from hardcodedValues", () => {
const node = createTestNode({
id: "1",
data: {
hardcodedValues: { filled: "value", empty: "" },
},
});
const backend = useNodeStore
.getState()
.convertCustomNodeToBackendNode(node);
expect(backend.input_default).toEqual({ filled: "value" });
expect(backend.input_default).not.toHaveProperty("empty");
});
});
describe("getBackendNodes", () => {
it("converts all nodes to backend format", () => {
useNodeStore
.getState()
.setNodes([
createTestNode({ id: "1", position: { x: 0, y: 0 } }),
createTestNode({ id: "2", position: { x: 100, y: 100 } }),
]);
const backendNodes = useNodeStore.getState().getBackendNodes();
expect(backendNodes).toHaveLength(2);
expect(backendNodes[0].id).toBe("1");
expect(backendNodes[1].id).toBe("2");
});
});
describe("node status", () => {
it("returns undefined for a node with no status", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
});
it("updates node status", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
expect(useNodeStore.getState().getNodeStatus("1")).toBe("RUNNING");
useNodeStore.getState().updateNodeStatus("1", "COMPLETED");
expect(useNodeStore.getState().getNodeStatus("1")).toBe("COMPLETED");
});
it("cleans all node statuses", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
useNodeStore.getState().updateNodeStatus("2", "COMPLETED");
useNodeStore.getState().cleanNodesStatuses();
expect(useNodeStore.getState().getNodeStatus("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeStatus("2")).toBeUndefined();
});
it("updating status for non-existent node does not crash", () => {
useNodeStore.getState().updateNodeStatus("nonexistent", "RUNNING");
expect(
useNodeStore.getState().getNodeStatus("nonexistent"),
).toBeUndefined();
});
});
describe("execution result tracking", () => {
it("returns empty array for node with no results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
});
it("tracks a single execution result", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
const result = createExecutionResult({ node_id: "1" });
useNodeStore.getState().updateNodeExecutionResult("1", result);
const results = useNodeStore.getState().getNodeExecutionResults("1");
expect(results).toHaveLength(1);
expect(results[0].node_exec_id).toBe("exec-1");
});
it("accumulates multiple execution results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "val1" },
output_data: { key: ["out1"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "val2" },
output_data: { key: ["out2"] },
}),
);
expect(useNodeStore.getState().getNodeExecutionResults("1")).toHaveLength(
2,
);
});
it("updates latest input/output data", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "first" },
output_data: { key: ["first_out"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "second" },
output_data: { key: ["second_out"] },
}),
);
expect(useNodeStore.getState().getLatestNodeInputData("1")).toEqual({
key: "second",
});
expect(useNodeStore.getState().getLatestNodeOutputData("1")).toEqual({
key: ["second_out"],
});
});
it("accumulates input/output data across results", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "val1" },
output_data: { key: ["out1"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-2",
input_data: { key: "val2" },
output_data: { key: ["out2"] },
}),
);
const accInput = useNodeStore.getState().getAccumulatedNodeInputData("1");
expect(accInput.key).toEqual(["val1", "val2"]);
const accOutput = useNodeStore
.getState()
.getAccumulatedNodeOutputData("1");
expect(accOutput.key).toEqual(["out1", "out2"]);
});
it("deduplicates execution results by node_exec_id", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "original" },
output_data: { key: ["original_out"] },
}),
);
useNodeStore.getState().updateNodeExecutionResult(
"1",
createExecutionResult({
node_exec_id: "exec-1",
input_data: { key: "updated" },
output_data: { key: ["updated_out"] },
}),
);
const results = useNodeStore.getState().getNodeExecutionResults("1");
expect(results).toHaveLength(1);
expect(results[0].input_data).toEqual({ key: "updated" });
});
it("returns the latest execution result", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-1" }),
);
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-2" }),
);
const latest = useNodeStore.getState().getLatestNodeExecutionResult("1");
expect(latest?.node_exec_id).toBe("exec-2");
});
it("returns undefined for latest result on unknown node", () => {
expect(
useNodeStore.getState().getLatestNodeExecutionResult("unknown"),
).toBeUndefined();
});
it("clears all execution results", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore
.getState()
.updateNodeExecutionResult(
"1",
createExecutionResult({ node_exec_id: "exec-1" }),
);
useNodeStore
.getState()
.updateNodeExecutionResult(
"2",
createExecutionResult({ node_exec_id: "exec-2" }),
);
useNodeStore.getState().clearAllNodeExecutionResults();
expect(useNodeStore.getState().getNodeExecutionResults("1")).toEqual([]);
expect(useNodeStore.getState().getNodeExecutionResults("2")).toEqual([]);
expect(
useNodeStore.getState().getLatestNodeInputData("1"),
).toBeUndefined();
expect(
useNodeStore.getState().getLatestNodeOutputData("1"),
).toBeUndefined();
expect(useNodeStore.getState().getAccumulatedNodeInputData("1")).toEqual(
{},
);
expect(useNodeStore.getState().getAccumulatedNodeOutputData("1")).toEqual(
{},
);
});
it("returns empty object for accumulated data on unknown node", () => {
expect(
useNodeStore.getState().getAccumulatedNodeInputData("unknown"),
).toEqual({});
expect(
useNodeStore.getState().getAccumulatedNodeOutputData("unknown"),
).toEqual({});
});
});
describe("getNodeBlockUIType", () => {
it("returns the node UI type", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.INPUT,
},
}),
);
expect(useNodeStore.getState().getNodeBlockUIType("1")).toBe(
BlockUIType.INPUT,
);
});
it("defaults to STANDARD for unknown node IDs", () => {
expect(useNodeStore.getState().getNodeBlockUIType("unknown")).toBe(
BlockUIType.STANDARD,
);
});
});
describe("hasWebhookNodes", () => {
it("returns false when there are no webhook nodes", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().hasWebhookNodes()).toBe(false);
});
it("returns true when a WEBHOOK node exists", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.WEBHOOK,
},
}),
);
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
});
it("returns true when a WEBHOOK_MANUAL node exists", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
uiType: BlockUIType.WEBHOOK_MANUAL,
},
}),
);
expect(useNodeStore.getState().hasWebhookNodes()).toBe(true);
});
});
describe("node errors", () => {
it("returns undefined for a node with no errors", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
});
it("sets and retrieves node errors", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
const errors = { field1: "required", field2: "invalid" };
useNodeStore.getState().updateNodeErrors("1", errors);
expect(useNodeStore.getState().getNodeErrors("1")).toEqual(errors);
});
it("clears errors for a specific node", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeErrors("1", { f: "err" });
useNodeStore.getState().updateNodeErrors("2", { g: "err2" });
useNodeStore.getState().clearNodeErrors("1");
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeErrors("2")).toEqual({ g: "err2" });
});
it("clears all node errors", () => {
useNodeStore
.getState()
.setNodes([createTestNode({ id: "1" }), createTestNode({ id: "2" })]);
useNodeStore.getState().updateNodeErrors("1", { a: "err1" });
useNodeStore.getState().updateNodeErrors("2", { b: "err2" });
useNodeStore.getState().clearAllNodeErrors();
expect(useNodeStore.getState().getNodeErrors("1")).toBeUndefined();
expect(useNodeStore.getState().getNodeErrors("2")).toBeUndefined();
});
it("sets errors by backend ID matching node id", () => {
useNodeStore.getState().addNode(createTestNode({ id: "backend-1" }));
useNodeStore
.getState()
.setNodeErrorsForBackendId("backend-1", { x: "error" });
expect(useNodeStore.getState().getNodeErrors("backend-1")).toEqual({
x: "error",
});
});
});
describe("getHardCodedValues", () => {
it("returns hardcoded values for a node", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
hardcodedValues: { key: "value" },
},
}),
);
expect(useNodeStore.getState().getHardCodedValues("1")).toEqual({
key: "value",
});
});
it("returns empty object for unknown node", () => {
expect(useNodeStore.getState().getHardCodedValues("unknown")).toEqual({});
});
});
describe("credentials optional", () => {
it("sets credentials_optional in node metadata", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore.getState().setCredentialsOptional("1", true);
const node = useNodeStore.getState().nodes[0];
expect(node.data.metadata?.credentials_optional).toBe(true);
});
});
describe("resolution mode", () => {
it("defaults to not in resolution mode", () => {
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
});
it("enters and exits resolution mode", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(true);
useNodeStore.getState().setNodeResolutionMode("1", false);
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
});
it("tracks broken edge IDs", () => {
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(true);
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
expect(useNodeStore.getState().isEdgeBroken("edge-3")).toBe(false);
});
it("removes individual broken edge IDs", () => {
useNodeStore.getState().setBrokenEdgeIDs("node-1", ["edge-1", "edge-2"]);
useNodeStore.getState().removeBrokenEdgeID("node-1", "edge-1");
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
expect(useNodeStore.getState().isEdgeBroken("edge-2")).toBe(true);
});
it("clears all resolution state", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
useNodeStore.getState().clearResolutionState();
expect(useNodeStore.getState().isNodeInResolutionMode("1")).toBe(false);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
});
it("cleans up broken edges when exiting resolution mode", () => {
useNodeStore.getState().setNodeResolutionMode("1", true);
useNodeStore.getState().setBrokenEdgeIDs("1", ["edge-1"]);
useNodeStore.getState().setNodeResolutionMode("1", false);
expect(useNodeStore.getState().isEdgeBroken("edge-1")).toBe(false);
});
});
describe("edge cases", () => {
it("handles updating data on a non-existent node gracefully", () => {
useNodeStore
.getState()
.updateNodeData("nonexistent", { title: "New Title" });
expect(useNodeStore.getState().nodes).toHaveLength(0);
});
it("handles removing a non-existent node gracefully", () => {
useNodeStore.getState().addNode(createTestNode({ id: "1" }));
useNodeStore
.getState()
.onNodesChange([{ type: "remove", id: "nonexistent" }]);
expect(useNodeStore.getState().nodes).toHaveLength(1);
});
it("handles duplicate node IDs in addNodes", () => {
useNodeStore.getState().addNodes([
createTestNode({
id: "1",
data: { title: "First" },
}),
createTestNode({
id: "1",
data: { title: "Second" },
}),
]);
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
expect(nodes[0].data.title).toBe("First");
expect(nodes[1].data.title).toBe("Second");
});
it("updating node status mid-execution preserves other data", () => {
useNodeStore.getState().addNode(
createTestNode({
id: "1",
data: {
title: "My Node",
hardcodedValues: { key: "val" },
},
}),
);
useNodeStore.getState().updateNodeStatus("1", "RUNNING");
const node = useNodeStore.getState().nodes[0];
expect(node.data.status).toBe("RUNNING");
expect(node.data.title).toBe("My Node");
expect(node.data.hardcodedValues).toEqual({ key: "val" });
});
it("execution result for non-existent node does not add it", () => {
useNodeStore
.getState()
.updateNodeExecutionResult(
"nonexistent",
createExecutionResult({ node_exec_id: "exec-1" }),
);
expect(useNodeStore.getState().nodes).toHaveLength(0);
expect(
useNodeStore.getState().getNodeExecutionResults("nonexistent"),
).toEqual([]);
});
it("getBackendNodes returns empty array when no nodes exist", () => {
expect(useNodeStore.getState().getBackendNodes()).toEqual([]);
});
});
});

View File

@@ -0,0 +1,567 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { CustomNode } from "../components/FlowEditor/nodes/CustomNode/CustomNode";
import { BlockUIType } from "../components/types";
// ---- Mocks ----
const mockGetViewport = vi.fn(() => ({ x: 0, y: 0, zoom: 1 }));
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: vi.fn(() => ({
getViewport: mockGetViewport,
})),
};
});
const mockToast = vi.fn();
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: vi.fn(() => ({ toast: mockToast })),
}));
let uuidCounter = 0;
vi.mock("uuid", () => ({
v4: vi.fn(() => `new-uuid-${++uuidCounter}`),
}));
// Mock navigator.clipboard
const mockWriteText = vi.fn(() => Promise.resolve());
const mockReadText = vi.fn(() => Promise.resolve(""));
Object.defineProperty(navigator, "clipboard", {
value: {
writeText: mockWriteText,
readText: mockReadText,
},
writable: true,
configurable: true,
});
// Mock window.innerWidth / innerHeight for viewport centering calculations
Object.defineProperty(window, "innerWidth", { value: 1000, writable: true });
Object.defineProperty(window, "innerHeight", { value: 800, writable: true });
import { useCopyPaste } from "../components/FlowEditor/Flow/useCopyPaste";
import { useNodeStore } from "../stores/nodeStore";
import { useEdgeStore } from "../stores/edgeStore";
import { useHistoryStore } from "../stores/historyStore";
import { CustomEdge } from "../components/FlowEditor/edges/CustomEdge";
const CLIPBOARD_PREFIX = "autogpt-flow-data:";
function createTestNode(
id: string,
overrides: Partial<CustomNode> = {},
): CustomNode {
return {
id,
type: "custom",
position: overrides.position ?? { x: 100, y: 200 },
selected: overrides.selected,
data: {
hardcodedValues: {},
title: `Node ${id}`,
description: "test node",
inputSchema: {},
outputSchema: {},
uiType: BlockUIType.STANDARD,
block_id: `block-${id}`,
costs: [],
categories: [],
...overrides.data,
},
} as CustomNode;
}
function createTestEdge(
id: string,
source: string,
target: string,
sourceHandle = "out",
targetHandle = "in",
): CustomEdge {
return {
id,
source,
target,
sourceHandle,
targetHandle,
} as CustomEdge;
}
function makeCopyEvent(): KeyboardEvent {
return new KeyboardEvent("keydown", {
key: "c",
ctrlKey: true,
bubbles: true,
});
}
function makePasteEvent(): KeyboardEvent {
return new KeyboardEvent("keydown", {
key: "v",
ctrlKey: true,
bubbles: true,
});
}
function clipboardPayload(nodes: CustomNode[], edges: CustomEdge[]): string {
return `${CLIPBOARD_PREFIX}${JSON.stringify({ nodes, edges })}`;
}
describe("useCopyPaste", () => {
beforeEach(() => {
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
useHistoryStore.getState().clear();
mockWriteText.mockClear();
mockReadText.mockClear();
mockToast.mockClear();
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
uuidCounter = 0;
// Ensure no input element is focused
if (document.activeElement && document.activeElement !== document.body) {
(document.activeElement as HTMLElement).blur();
}
});
describe("copy (Ctrl+C)", () => {
it("copies a single selected node to clipboard with prefix", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const written = (mockWriteText.mock.calls as string[][])[0][0];
expect(written.startsWith(CLIPBOARD_PREFIX)).toBe(true);
const parsed = JSON.parse(written.slice(CLIPBOARD_PREFIX.length));
expect(parsed.nodes).toHaveLength(1);
expect(parsed.nodes[0].id).toBe("1");
expect(parsed.edges).toHaveLength(0);
});
it("shows a success toast after copying", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Copied successfully",
}),
);
});
});
it("copies multiple connected nodes and preserves internal edges", async () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: true });
const nodeC = createTestNode("c", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB, nodeC] });
useEdgeStore.setState({
edges: [
createTestEdge("e-ab", "a", "b"),
createTestEdge("e-bc", "b", "c"),
],
});
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(2);
expect(parsed.edges).toHaveLength(1);
expect(parsed.edges[0].id).toBe("e-ab");
});
it("drops external edges where one endpoint is not selected", async () => {
const nodeA = createTestNode("a", { selected: true });
const nodeB = createTestNode("b", { selected: false });
useNodeStore.setState({ nodes: [nodeA, nodeB] });
useEdgeStore.setState({
edges: [createTestEdge("e-ab", "a", "b")],
});
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(1);
expect(parsed.edges).toHaveLength(0);
});
it("copies nothing when no nodes are selected", async () => {
const node = createTestNode("1", { selected: false });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
const parsed = JSON.parse(
(mockWriteText.mock.calls as string[][])[0][0].slice(
CLIPBOARD_PREFIX.length,
),
);
expect(parsed.nodes).toHaveLength(0);
expect(parsed.edges).toHaveLength(0);
});
});
describe("paste (Ctrl+V)", () => {
it("creates new nodes with new UUIDs", async () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 200 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
const { nodes } = useNodeStore.getState();
expect(nodes[0].id).toBe("new-uuid-1");
expect(nodes[0].id).not.toBe("orig");
});
it("centers pasted nodes in the current viewport", async () => {
// Viewport at origin, zoom 1 => center = (500, 400)
mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 });
const node = createTestNode("orig", {
selected: true,
position: { x: 100, y: 100 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
const { nodes } = useNodeStore.getState();
// Single node: center of bounds = (100, 100)
// Viewport center = (500, 400)
// Offset = (400, 300)
// New position = (100 + 400, 100 + 300) = (500, 400)
expect(nodes[0].position).toEqual({ x: 500, y: 400 });
});
it("deselects existing nodes and selects pasted nodes", async () => {
const existingNode = createTestNode("existing", {
selected: true,
position: { x: 0, y: 0 },
});
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const nodeToPaste = createTestNode("paste-me", {
selected: false,
position: { x: 100, y: 100 },
});
mockReadText.mockResolvedValue(clipboardPayload([nodeToPaste], []));
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
});
const { nodes } = useNodeStore.getState();
const originalNode = nodes.find((n) => n.id === "existing");
const pastedNode = nodes.find((n) => n.id !== "existing");
expect(originalNode!.selected).toBe(false);
expect(pastedNode!.selected).toBe(true);
});
it("remaps edge source/target IDs to newly created node IDs", async () => {
const nodeA = createTestNode("a", {
selected: true,
position: { x: 0, y: 0 },
});
const nodeB = createTestNode("b", {
selected: true,
position: { x: 200, y: 0 },
});
const edge = createTestEdge("e-ab", "a", "b", "output", "input");
mockReadText.mockResolvedValue(clipboardPayload([nodeA, nodeB], [edge]));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
useEdgeStore.setState({ edges: [] });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(2);
});
// Wait for edges to be added too
await vi.waitFor(() => {
const { edges } = useEdgeStore.getState();
expect(edges).toHaveLength(1);
});
const { edges } = useEdgeStore.getState();
const newEdge = edges[0];
// Edge source/target should be remapped to new UUIDs, not "a"/"b"
expect(newEdge.source).not.toBe("a");
expect(newEdge.target).not.toBe("b");
expect(newEdge.source).toBe("new-uuid-1");
expect(newEdge.target).toBe("new-uuid-2");
expect(newEdge.sourceHandle).toBe("output");
expect(newEdge.targetHandle).toBe("input");
});
it("does nothing when clipboard does not have the expected prefix", async () => {
mockReadText.mockResolvedValue("some random text");
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
// Give async operations time to settle
await vi.waitFor(() => {
expect(mockReadText).toHaveBeenCalled();
});
// Ensure no state changes happen after clipboard read
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
it("does nothing when clipboard is empty", async () => {
mockReadText.mockResolvedValue("");
const existingNode = createTestNode("1", { position: { x: 0, y: 0 } });
useNodeStore.setState({ nodes: [existingNode], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
await vi.waitFor(() => {
expect(mockReadText).toHaveBeenCalled();
});
// Ensure no state changes happen after clipboard read
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
expect(nodes[0].id).toBe("1");
});
});
});
describe("input field focus guard", () => {
it("ignores Ctrl+C when an input element is focused", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const input = document.createElement("input");
document.body.appendChild(input);
input.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
// Clipboard write should NOT be called
expect(mockWriteText).not.toHaveBeenCalled();
document.body.removeChild(input);
});
it("ignores Ctrl+V when a textarea element is focused", async () => {
mockReadText.mockResolvedValue(
clipboardPayload(
[createTestNode("a", { position: { x: 0, y: 0 } })],
[],
),
);
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const textarea = document.createElement("textarea");
document.body.appendChild(textarea);
textarea.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makePasteEvent());
});
expect(mockReadText).not.toHaveBeenCalled();
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(0);
document.body.removeChild(textarea);
});
it("ignores keypresses when a contenteditable element is focused", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const div = document.createElement("div");
div.setAttribute("contenteditable", "true");
document.body.appendChild(div);
div.focus();
const { result } = renderHook(() => useCopyPaste());
act(() => {
result.current(makeCopyEvent());
});
expect(mockWriteText).not.toHaveBeenCalled();
document.body.removeChild(div);
});
});
describe("meta key support (macOS)", () => {
it("handles Cmd+C (metaKey) the same as Ctrl+C", async () => {
const node = createTestNode("1", { selected: true });
useNodeStore.setState({ nodes: [node] });
const { result } = renderHook(() => useCopyPaste());
const metaCopyEvent = new KeyboardEvent("keydown", {
key: "c",
metaKey: true,
bubbles: true,
});
act(() => {
result.current(metaCopyEvent);
});
await vi.waitFor(() => {
expect(mockWriteText).toHaveBeenCalledTimes(1);
});
});
it("handles Cmd+V (metaKey) the same as Ctrl+V", async () => {
const node = createTestNode("orig", {
selected: true,
position: { x: 0, y: 0 },
});
mockReadText.mockResolvedValue(clipboardPayload([node], []));
useNodeStore.setState({ nodes: [], nodeCounter: 0 });
const { result } = renderHook(() => useCopyPaste());
const metaPasteEvent = new KeyboardEvent("keydown", {
key: "v",
metaKey: true,
bubbles: true,
});
act(() => {
result.current(metaPasteEvent);
});
await vi.waitFor(() => {
const { nodes } = useNodeStore.getState();
expect(nodes).toHaveLength(1);
});
});
});
});

View File

@@ -0,0 +1,134 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
const mockScreenToFlowPosition = vi.fn((pos: { x: number; y: number }) => pos);
const mockFitView = vi.fn();
vi.mock("@xyflow/react", async () => {
const actual = await vi.importActual("@xyflow/react");
return {
...actual,
useReactFlow: () => ({
screenToFlowPosition: mockScreenToFlowPosition,
fitView: mockFitView,
}),
};
});
const mockSetQueryStates = vi.fn();
let mockQueryStateValues: {
flowID: string | null;
flowVersion: number | null;
flowExecutionID: string | null;
} = {
flowID: null,
flowVersion: null,
flowExecutionID: null,
};
vi.mock("nuqs", () => ({
parseAsString: {},
parseAsInteger: {},
useQueryStates: vi.fn(() => [mockQueryStateValues, mockSetQueryStates]),
}));
let mockGraphLoading = false;
let mockBlocksLoading = false;
vi.mock("@/app/api/__generated__/endpoints/graphs/graphs", () => ({
useGetV1GetSpecificGraph: vi.fn(() => ({
data: undefined,
isLoading: mockGraphLoading,
})),
useGetV1GetExecutionDetails: vi.fn(() => ({
data: undefined,
})),
useGetV1ListUserGraphs: vi.fn(() => ({
data: undefined,
})),
}));
vi.mock("@/app/api/__generated__/endpoints/default/default", () => ({
useGetV2GetSpecificBlocks: vi.fn(() => ({
data: undefined,
isLoading: mockBlocksLoading,
})),
}));
vi.mock("@/app/api/helpers", () => ({
okData: (res: { data: unknown }) => res?.data,
}));
vi.mock("../components/helper", () => ({
convertNodesPlusBlockInfoIntoCustomNodes: vi.fn(),
}));
describe("useFlow", () => {
beforeEach(() => {
vi.clearAllMocks();
vi.useFakeTimers({ shouldAdvanceTime: true });
mockGraphLoading = false;
mockBlocksLoading = false;
mockQueryStateValues = {
flowID: null,
flowVersion: null,
flowExecutionID: null,
};
});
afterEach(() => {
vi.useRealTimers();
});
describe("loading states", () => {
it("returns isFlowContentLoading true when graph is loading", async () => {
mockGraphLoading = true;
mockQueryStateValues = {
flowID: "test-flow",
flowVersion: 1,
flowExecutionID: null,
};
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(true);
});
it("returns isFlowContentLoading true when blocks are loading", async () => {
mockBlocksLoading = true;
mockQueryStateValues = {
flowID: "test-flow",
flowVersion: 1,
flowExecutionID: null,
};
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(true);
});
it("returns isFlowContentLoading false when neither is loading", async () => {
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isFlowContentLoading).toBe(false);
});
});
describe("initial load completion", () => {
it("marks initial load complete for new flows without flowID", async () => {
const { useFlow } = await import("../components/FlowEditor/Flow/useFlow");
const { result } = renderHook(() => useFlow());
expect(result.current.isInitialLoadComplete).toBe(false);
await act(async () => {
vi.advanceTimersByTime(300);
});
expect(result.current.isInitialLoadComplete).toBe(true);
});
});
});

View File

@@ -1,3 +1,4 @@
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
interface Args {
@@ -16,6 +17,16 @@ export function useChatInput({
}: Args) {
const [value, setValue] = useState("");
const [isSending, setIsSending] = useState(false);
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
useEffect(
function consumeInitialPrompt() {
if (!initialPrompt) return;
setValue((prev) => (prev.length === 0 ? initialPrompt : prev));
setInitialPrompt(null);
},
[initialPrompt, setInitialPrompt],
);
useEffect(
function focusOnMount() {

View File

@@ -23,22 +23,25 @@ import {
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { DotsThree, PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import {
CheckCircle,
DotsThree,
PlusCircleIcon,
PlusIcon,
} from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query";
import { motion } from "framer-motion";
import { AnimatePresence, motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useRef, useState } from "react";
import { getSessionListParams } from "../../helpers";
import { useCopilotUIStore } from "../../store";
import { SessionListItem } from "../SessionListItem/SessionListItem";
import { NotificationToggle } from "./components/NotificationToggle/NotificationToggle";
import { DeleteChatDialog } from "../DeleteChatDialog/DeleteChatDialog";
import { PulseLoader } from "../PulseLoader/PulseLoader";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const listSessionsParams = getSessionListParams();
const {
sessionToDelete,
setSessionToDelete,
@@ -49,9 +52,7 @@ export function ChatSidebar() {
const queryClient = useQueryClient();
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions(listSessionsParams, {
query: { refetchInterval: 10_000 },
});
useGetV2ListSessions({ limit: 50 }, { query: { refetchInterval: 10_000 } });
const { mutate: deleteSession, isPending: isDeleting } =
useDeleteV2DeleteSession({
@@ -179,6 +180,31 @@ export function ChatSidebar() {
}
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
return (
<>
<Sidebar
@@ -269,17 +295,17 @@ export function ChatSidebar() {
No conversations yet
</p>
) : (
sessions.map((session) =>
editingSessionId === session.id ? (
<div
key={session.id}
className={cn(
"group relative w-full rounded-lg transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
sessions.map((session) => (
<div
key={session.id}
className={cn(
"group relative w-full rounded-lg transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
{editingSessionId === session.id ? (
<div className="px-3 py-2.5">
<input
ref={renameInputRef}
@@ -305,49 +331,87 @@ export function ChatSidebar() {
className="w-full rounded border border-zinc-300 bg-white px-2 py-1 text-sm text-zinc-800 outline-none focus:border-purple-500 focus:ring-1 focus:ring-purple-500"
/>
</div>
</div>
) : (
<SessionListItem
key={session.id}
session={session}
currentSessionId={sessionId}
isCompleted={completedSessionIDs.has(session.id)}
onSelect={handleSelectSession}
variant="sidebar"
actionSlot={
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
) : (
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full items-center gap-2">
<div className="min-w-0 flex-1">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleRenameClick(e, session.id, session.title)
}
>
Rename
</DropdownMenuItem>
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
}
/>
),
)
<AnimatePresence mode="wait" initial={false}>
<motion.span
key={session.title || "untitled"}
initial={{ opacity: 0, y: 4 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -4 }}
transition={{ duration: 0.2 }}
className="block truncate"
>
{session.title || "Untitled chat"}
</motion.span>
</AnimatePresence>
</Text>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
{session.is_processing &&
session.id !== sessionId &&
!completedSessionIDs.has(session.id) && (
<PulseLoader size={16} className="shrink-0" />
)}
{completedSessionIDs.has(session.id) &&
session.id !== sessionId && (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
)}
</div>
</button>
)}
{editingSessionId !== session.id && (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
onClick={(e) => e.stopPropagation()}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded-full p-1.5 text-zinc-600 transition-all hover:bg-neutral-100"
aria-label="More actions"
>
<DotsThree className="h-4 w-4" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
onClick={(e) =>
handleRenameClick(e, session.id, session.title)
}
>
Rename
</DropdownMenuItem>
<DropdownMenuItem
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="text-red-600 focus:bg-red-50 focus:text-red-600"
>
Delete chat
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)}
</div>
))
)}
</motion.div>
)}

View File

@@ -1,7 +1,10 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import {
CheckCircle,
PlusIcon,
SpeakerHigh,
SpeakerSlash,
@@ -9,9 +12,8 @@ import {
X,
} from "@phosphor-icons/react";
import { Drawer } from "vaul";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useCopilotUIStore } from "../../store";
import { SessionListItem } from "../SessionListItem/SessionListItem";
import { PulseLoader } from "../PulseLoader/PulseLoader";
interface Props {
isOpen: boolean;
@@ -24,6 +26,31 @@ interface Props {
onOpenChange: (open: boolean) => void;
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
export function MobileDrawer({
isOpen,
sessions,
@@ -107,19 +134,52 @@ export function MobileDrawer({
</p>
) : (
sessions.map((session) => (
<SessionListItem
<button
key={session.id}
session={session}
currentSessionId={currentSessionId}
isCompleted={completedSessionIDs.has(session.id)}
variant="drawer"
onSelect={(selectedSessionId) => {
onSelectSession(selectedSessionId);
if (completedSessionIDs.has(selectedSessionId)) {
clearCompletedSession(selectedSessionId);
onClick={() => {
onSelectSession(session.id);
if (completedSessionIDs.has(session.id)) {
clearCompletedSession(session.id);
}
}}
/>
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === currentSessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="flex min-w-0 max-w-full items-center gap-1.5">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === currentSessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || "Untitled chat"}
</Text>
{session.is_processing &&
!completedSessionIDs.has(session.id) &&
session.id !== currentSessionId && (
<PulseLoader size={8} className="shrink-0" />
)}
{completedSessionIDs.has(session.id) &&
session.id !== currentSessionId && (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
)}
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</div>

View File

@@ -1,148 +0,0 @@
"use client";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Badge } from "@/components/atoms/Badge/Badge";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { CheckCircle } from "@phosphor-icons/react";
import { AnimatePresence, motion } from "framer-motion";
import type { ReactNode } from "react";
import {
formatSessionDate,
getSessionStartTypeLabel,
isNonManualSessionStartType,
} from "../../helpers";
import { PulseLoader } from "../PulseLoader/PulseLoader";
interface Props {
actionSlot?: ReactNode;
currentSessionId: string | null;
isCompleted: boolean;
onSelect: (sessionId: string) => void;
session: SessionSummaryResponse;
variant?: "sidebar" | "drawer";
}
export function SessionListItem({
actionSlot,
currentSessionId,
isCompleted,
onSelect,
session,
variant = "sidebar",
}: Props) {
const isActive = session.id === currentSessionId;
const showProcessing = session.is_processing && !isCompleted && !isActive;
const showCompleted = isCompleted && !isActive;
const startTypeLabel = isNonManualSessionStartType(session.start_type)
? getSessionStartTypeLabel(session.start_type)
: null;
if (variant === "drawer") {
return (
<button
onClick={() => onSelect(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="flex min-w-0 max-w-full items-center gap-1.5">
<Text
variant="body"
className={cn(
"truncate font-normal",
isActive ? "text-zinc-600" : "text-zinc-800",
)}
>
{session.title || "Untitled chat"}
</Text>
{showProcessing ? (
<PulseLoader size={8} className="shrink-0" />
) : null}
{showCompleted ? (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
) : null}
</div>
{startTypeLabel ? (
<div className="mt-1">
<Badge variant="info" size="small">
{startTypeLabel}
</Badge>
</div>
) : null}
<Text variant="small" className="text-neutral-400">
{formatSessionDate(session.updated_at)}
</Text>
</div>
</button>
);
}
return (
<div
className={cn(
"group relative w-full rounded-lg transition-colors",
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
)}
>
<button
onClick={() => onSelect(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full items-center gap-2">
<div className="min-w-0 flex-1">
<Text
variant="body"
className={cn(
"truncate font-normal",
isActive ? "text-zinc-600" : "text-zinc-800",
)}
>
<AnimatePresence mode="wait" initial={false}>
<motion.span
key={session.title || "untitled"}
initial={{ opacity: 0, y: 4 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -4 }}
transition={{ duration: 0.2 }}
className="block truncate"
>
{session.title || "Untitled chat"}
</motion.span>
</AnimatePresence>
</Text>
{startTypeLabel ? (
<div className="mt-1">
<Badge variant="info" size="small">
{startTypeLabel}
</Badge>
</div>
) : null}
<Text variant="small" className="text-neutral-400">
{formatSessionDate(session.updated_at)}
</Text>
</div>
{showProcessing ? (
<PulseLoader size={16} className="shrink-0" />
) : null}
{showCompleted ? (
<CheckCircle
className="h-4 w-4 shrink-0 text-green-500"
weight="fill"
/>
) : null}
</div>
</button>
{actionSlot ? (
<div className="absolute right-2 top-1/2 -translate-y-1/2">
{actionSlot}
</div>
) : null}
</div>
);
}

View File

@@ -1,65 +1,5 @@
import type { GetV2ListSessionsParams } from "@/app/api/__generated__/models/getV2ListSessionsParams";
import {
ChatSessionStartType,
type ChatSessionStartType as ChatSessionStartTypeValue,
} from "@/app/api/__generated__/models/chatSessionStartType";
import type { UIMessage } from "ai";
export const COPILOT_SESSION_LIST_LIMIT = 50;
export function getSessionListParams(): GetV2ListSessionsParams {
return {
limit: COPILOT_SESSION_LIST_LIMIT,
with_auto: true,
};
}
export function isNonManualSessionStartType(
startType: ChatSessionStartTypeValue | null | undefined,
): boolean {
return startType != null && startType !== ChatSessionStartType.MANUAL;
}
export function getSessionStartTypeLabel(
startType: ChatSessionStartTypeValue,
): string | null {
switch (startType) {
case ChatSessionStartType.AUTOPILOT_NIGHTLY:
return "Nightly";
case ChatSessionStartType.AUTOPILOT_CALLBACK:
return "Callback";
case ChatSessionStartType.AUTOPILOT_INVITE_CTA:
return "Invite CTA";
default:
return null;
}
}
export function formatSessionDate(dateString: string): string {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
/** Mark any in-progress tool parts as completed/errored so spinners stop. */
export function resolveInProgressTools(
messages: UIMessage[],

View File

@@ -7,6 +7,10 @@ export interface DeleteTarget {
}
interface CopilotUIState {
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
initialPrompt: string | null;
setInitialPrompt: (prompt: string | null) => void;
sessionToDelete: DeleteTarget | null;
setSessionToDelete: (target: DeleteTarget | null) => void;
@@ -31,6 +35,9 @@ interface CopilotUIState {
}
export const useCopilotUIStore = create<CopilotUIState>((set) => ({
initialPrompt: null,
setInitialPrompt: (prompt) => set({ initialPrompt: prompt }),
sessionToDelete: null,
setSessionToDelete: (target) => set({ sessionToDelete: target }),

View File

@@ -1,93 +0,0 @@
import { usePostV2ConsumeCallbackTokenRoute } from "@/app/api/__generated__/endpoints/chat/chat";
import { toast } from "@/components/molecules/Toast/use-toast";
import { useQueryClient } from "@tanstack/react-query";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useState } from "react";
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
interface Props {
isLoggedIn: boolean;
onConsumed: (sessionId: string) => void;
onClearAutopilot: () => void;
}
export function useCallbackToken({
isLoggedIn,
onConsumed,
onClearAutopilot,
}: Props) {
const queryClient = useQueryClient();
const [callbackToken, setCallbackToken] = useQueryState(
"callbackToken",
parseAsString,
);
const [consumedTokens, setConsumedTokens] = useState<Set<string>>(
() => new Set(),
);
const { mutateAsync: consumeCallbackToken, isPending } =
usePostV2ConsumeCallbackTokenRoute();
const hasConsumedToken =
callbackToken != null && consumedTokens.has(callbackToken);
useEffect(() => {
if (!isLoggedIn || !callbackToken || hasConsumedToken) {
return;
}
let isCancelled = false;
const token = callbackToken;
setConsumedTokens((current) => new Set(current).add(token));
void consumeCallbackToken({ data: { token } })
.then((response) => {
if (isCancelled) {
return;
}
if (response.status !== 200 || !response.data?.session_id) {
throw new Error("Failed to open callback session");
}
onConsumed(response.data.session_id);
onClearAutopilot();
void setCallbackToken(null);
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
})
.catch((error) => {
if (isCancelled) {
return;
}
setConsumedTokens((current) => {
const next = new Set(current);
next.delete(token);
return next;
});
void setCallbackToken(null);
toast({
title: "Unable to open callback session",
description:
error instanceof Error ? error.message : "Please try again.",
variant: "destructive",
});
});
return () => {
isCancelled = true;
};
}, [
callbackToken,
consumeCallbackToken,
hasConsumedToken,
isLoggedIn,
onClearAutopilot,
onConsumed,
queryClient,
setCallbackToken,
]);
return {
isConsumingCallbackToken: isPending,
};
}

View File

@@ -4,7 +4,6 @@ import {
useGetV2GetSession,
usePostV2CreateSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import type { ChatSessionStartType } from "@/app/api/__generated__/models/chatSessionStartType";
import { toast } from "@/components/molecules/Toast/use-toast";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
@@ -71,14 +70,6 @@ export function useChatSession() {
);
}, [sessionQuery.data, sessionId, hasActiveStream]);
const sessionStartType = useMemo<ChatSessionStartType | null>(() => {
if (sessionQuery.data?.status !== 200) {
return null;
}
return sessionQuery.data.data.start_type;
}, [sessionQuery.data]);
const { mutateAsync: createSessionMutation, isPending: isCreatingSession } =
usePostV2CreateSession({
mutation: {
@@ -130,7 +121,6 @@ export function useChatSession() {
return {
sessionId,
setSessionId,
sessionStartType,
hydratedMessages,
hasActiveStream,
isLoadingSession: sessionQuery.isLoading,

View File

@@ -2,24 +2,70 @@ import {
getGetV2ListSessionsQueryKey,
useDeleteV2DeleteSession,
useGetV2ListSessions,
type getV2ListSessionsResponse,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { toast } from "@/components/molecules/Toast/use-toast";
import { uploadFileDirect } from "@/lib/direct-upload";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { getSessionListParams } from "./helpers";
import type { FileUIPart } from "ai";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "./store";
import { useCallbackToken } from "./useCallbackToken";
import { useChatSession } from "./useChatSession";
import { useFileUpload } from "./useFileUpload";
import { useCopilotNotifications } from "./useCopilotNotifications";
import { useCopilotStream } from "./useCopilotStream";
import { useTitlePolling } from "./useTitlePolling";
const TITLE_POLL_INTERVAL_MS = 2_000;
const TITLE_POLL_MAX_ATTEMPTS = 5;
/**
* Extract a prompt from the URL hash fragment.
* Supports: /copilot#prompt=URL-encoded-text
* Optionally auto-submits if ?autosubmit=true is in the query string.
* Returns null if no prompt is present.
*/
function extractPromptFromUrl(): {
prompt: string;
autosubmit: boolean;
} | null {
if (typeof window === "undefined") return null;
const hash = window.location.hash;
if (!hash) return null;
const hashParams = new URLSearchParams(hash.slice(1));
const prompt = hashParams.get("prompt");
if (!prompt || !prompt.trim()) return null;
const searchParams = new URLSearchParams(window.location.search);
const autosubmit = searchParams.get("autosubmit") === "true";
// Clean up hash + autosubmit param only (preserve other query params)
const cleanURL = new URL(window.location.href);
cleanURL.hash = "";
cleanURL.searchParams.delete("autosubmit");
window.history.replaceState(
null,
"",
`${cleanURL.pathname}${cleanURL.search}`,
);
return { prompt: prompt.trim(), autosubmit };
}
interface UploadedFile {
file_id: string;
name: string;
mime_type: string;
}
export function useCopilotPage() {
const { isUserLoading, isLoggedIn } = useSupabase();
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const queryClient = useQueryClient();
const listSessionsParams = getSessionListParams();
const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
useCopilotUIStore();
@@ -29,7 +75,7 @@ export function useCopilotPage() {
setSessionId,
hydratedMessages,
hasActiveStream,
isLoadingSession: isLoadingCurrentSession,
isLoadingSession,
isSessionError,
createSession,
isCreatingSession,
@@ -83,33 +129,220 @@ export function useCopilotPage() {
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const { isConsumingCallbackToken } = useCallbackToken({
isLoggedIn,
onConsumed: setSessionId,
onClearAutopilot() {},
});
const pendingFilesRef = useRef<File[]>([]);
const { isUploadingFiles, onSend } = useFileUpload({
createSession,
isUserStoppingRef,
sendMessage,
sessionId,
});
// --- Send pending message after session creation ---
useEffect(() => {
if (!sessionId || pendingMessage === null) return;
const msg = pendingMessage;
const files = pendingFilesRef.current;
setPendingMessage(null);
pendingFilesRef.current = [];
if (files.length > 0) {
setIsUploadingFiles(true);
void uploadFiles(files, sessionId)
.then((uploaded) => {
if (uploaded.length === 0) {
toast({
title: "File upload failed",
description: "Could not upload any files. Please try again.",
variant: "destructive",
});
return;
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: msg,
files: fileParts.length > 0 ? fileParts : undefined,
});
})
.finally(() => setIsUploadingFiles(false));
} else {
sendMessage({ text: msg });
}
}, [sessionId, pendingMessage, sendMessage]);
// --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) ---
const { setInitialPrompt } = useCopilotUIStore();
const hasProcessedUrlPrompt = useRef(false);
useEffect(() => {
if (hasProcessedUrlPrompt.current) return;
const urlPrompt = extractPromptFromUrl();
if (!urlPrompt) return;
hasProcessedUrlPrompt.current = true;
if (urlPrompt.autosubmit) {
setPendingMessage(urlPrompt.prompt);
void createSession().catch(() => {
setPendingMessage(null);
setInitialPrompt(urlPrompt.prompt);
});
} else {
setInitialPrompt(urlPrompt.prompt);
}
}, [createSession, setInitialPrompt]);
async function uploadFiles(
files: File[],
sid: string,
): Promise<UploadedFile[]> {
const results = await Promise.allSettled(
files.map(async (file) => {
try {
const data = await uploadFileDirect(file, sid);
if (!data.file_id) throw new Error("No file_id returned");
return {
file_id: data.file_id,
name: data.name || file.name,
mime_type: data.mime_type || "application/octet-stream",
} as UploadedFile;
} catch (err) {
console.error("File upload failed:", err);
toast({
title: "File upload failed",
description: file.name,
variant: "destructive",
});
throw err;
}
}),
);
return results
.filter(
(r): r is PromiseFulfilledResult<UploadedFile> =>
r.status === "fulfilled",
)
.map((r) => r.value);
}
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
return uploaded.map((f) => ({
type: "file" as const,
mediaType: f.mime_type,
filename: f.name,
url: `/api/proxy/api/workspace/files/${f.file_id}/download`,
}));
}
async function onSend(message: string, files?: File[]) {
const trimmed = message.trim();
if (!trimmed && (!files || files.length === 0)) return;
// Client-side file limits
if (files && files.length > 0) {
const MAX_FILES = 10;
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
if (files.length > MAX_FILES) {
toast({
title: "Too many files",
description: `You can attach up to ${MAX_FILES} files at once.`,
variant: "destructive",
});
return;
}
const oversized = files.filter((f) => f.size > MAX_FILE_SIZE_BYTES);
if (oversized.length > 0) {
toast({
title: "File too large",
description: `${oversized[0].name} exceeds the 100 MB limit.`,
variant: "destructive",
});
return;
}
}
isUserStoppingRef.current = false;
if (sessionId) {
if (files && files.length > 0) {
setIsUploadingFiles(true);
try {
const uploaded = await uploadFiles(files, sessionId);
if (uploaded.length === 0) {
// All uploads failed — abort send so chips revert to editable
throw new Error("All file uploads failed");
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: trimmed || "",
files: fileParts.length > 0 ? fileParts : undefined,
});
} finally {
setIsUploadingFiles(false);
}
} else {
sendMessage({ text: trimmed });
}
return;
}
setPendingMessage(trimmed || "");
if (files && files.length > 0) {
pendingFilesRef.current = files;
}
await createSession();
}
// --- Session list (for mobile drawer & sidebar) ---
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions(listSessionsParams, {
query: { enabled: !isUserLoading && isLoggedIn },
});
useGetV2ListSessions(
{ limit: 50 },
{ query: { enabled: !isUserLoading && isLoggedIn } },
);
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
useTitlePolling({
isReconnecting,
sessionId,
status,
});
// Start title polling when stream ends cleanly — sidebar title animates in
const titlePollRef = useRef<ReturnType<typeof setInterval>>();
const prevStatusRef = useRef(status);
useEffect(() => {
const prev = prevStatusRef.current;
prevStatusRef.current = status;
const wasActive = prev === "streaming" || prev === "submitted";
const isNowReady = status === "ready";
if (!wasActive || !isNowReady || !sessionId || isReconnecting) return;
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
});
const sid = sessionId;
let attempts = 0;
clearInterval(titlePollRef.current);
titlePollRef.current = setInterval(() => {
const data = queryClient.getQueryData<getV2ListSessionsResponse>(
getGetV2ListSessionsQueryKey({ limit: 50 }),
);
const hasTitle =
data?.status === 200 &&
data.data.sessions.some((s) => s.id === sid && s.title);
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
clearInterval(titlePollRef.current);
titlePollRef.current = undefined;
return;
}
attempts += 1;
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey({ limit: 50 }),
});
}, TITLE_POLL_INTERVAL_MS);
}, [status, sessionId, isReconnecting, queryClient]);
// Clean up polling on session change or unmount
useEffect(() => {
return () => {
clearInterval(titlePollRef.current);
titlePollRef.current = undefined;
};
}, [sessionId]);
// --- Mobile drawer handlers ---
function handleOpenDrawer() {
@@ -159,7 +392,7 @@ export function useCopilotPage() {
error,
stop,
isReconnecting,
isLoadingSession: isLoadingCurrentSession || isConsumingCallbackToken,
isLoadingSession,
isSessionError,
isCreatingSession,
isUploadingFiles,

View File

@@ -1,178 +0,0 @@
import { uploadFileDirect } from "@/lib/direct-upload";
import type { FileUIPart } from "ai";
import { toast } from "@/components/molecules/Toast/use-toast";
import { useEffect, useRef, useState, type MutableRefObject } from "react";
const MAX_FILES = 10;
const MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024;
interface UploadedFile {
file_id: string;
name: string;
mime_type: string;
}
interface SendMessageInput {
text: string;
files?: FileUIPart[];
}
interface Props {
createSession: () => Promise<unknown>;
isUserStoppingRef: MutableRefObject<boolean>;
sendMessage: (input: SendMessageInput) => void;
sessionId: string | null;
}
async function uploadFiles(
files: File[],
sessionId: string,
): Promise<UploadedFile[]> {
const results = await Promise.allSettled(
files.map(async (file) => {
try {
const data = await uploadFileDirect(file, sessionId);
if (!data.file_id) throw new Error("No file_id returned");
return {
file_id: data.file_id,
name: data.name || file.name,
mime_type: data.mime_type || "application/octet-stream",
} satisfies UploadedFile;
} catch (error) {
console.error("File upload failed:", error);
toast({
title: "File upload failed",
description: file.name,
variant: "destructive",
});
throw error;
}
}),
);
return results
.filter(
(result): result is PromiseFulfilledResult<UploadedFile> =>
result.status === "fulfilled",
)
.map((result) => result.value);
}
function buildFileParts(uploaded: UploadedFile[]): FileUIPart[] {
return uploaded.map((file) => ({
type: "file" as const,
mediaType: file.mime_type,
filename: file.name,
url: `/api/proxy/api/workspace/files/${file.file_id}/download`,
}));
}
export function useFileUpload({
createSession,
isUserStoppingRef,
sendMessage,
sessionId,
}: Props) {
const [isUploadingFiles, setIsUploadingFiles] = useState(false);
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const pendingFilesRef = useRef<File[]>([]);
useEffect(() => {
if (!sessionId || pendingMessage === null) {
return;
}
const message = pendingMessage;
const files = pendingFilesRef.current;
setPendingMessage(null);
pendingFilesRef.current = [];
if (files.length === 0) {
sendMessage({ text: message });
return;
}
setIsUploadingFiles(true);
void uploadFiles(files, sessionId)
.then((uploaded) => {
if (uploaded.length === 0) {
toast({
title: "File upload failed",
description: "Could not upload any files. Please try again.",
variant: "destructive",
});
return;
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: message,
files: fileParts.length > 0 ? fileParts : undefined,
});
})
.finally(() => setIsUploadingFiles(false));
}, [pendingMessage, sendMessage, sessionId]);
async function onSend(message: string, files?: File[]) {
const trimmed = message.trim();
if (!trimmed && (!files || files.length === 0)) {
return;
}
if (files && files.length > 0) {
if (files.length > MAX_FILES) {
toast({
title: "Too many files",
description: `You can attach up to ${MAX_FILES} files at once.`,
variant: "destructive",
});
return;
}
const oversized = files.filter((file) => file.size > MAX_FILE_SIZE_BYTES);
if (oversized.length > 0) {
toast({
title: "File too large",
description: `${oversized[0].name} exceeds the 100 MB limit.`,
variant: "destructive",
});
return;
}
}
isUserStoppingRef.current = false;
if (sessionId) {
if (!files || files.length === 0) {
sendMessage({ text: trimmed });
return;
}
setIsUploadingFiles(true);
try {
const uploaded = await uploadFiles(files, sessionId);
if (uploaded.length === 0) {
throw new Error("All file uploads failed");
}
const fileParts = buildFileParts(uploaded);
sendMessage({
text: trimmed || "",
files: fileParts.length > 0 ? fileParts : undefined,
});
} finally {
setIsUploadingFiles(false);
}
return;
}
setPendingMessage(trimmed || "");
pendingFilesRef.current = files ?? [];
await createSession();
}
return {
isUploadingFiles,
onSend,
};
}

View File

@@ -1,72 +0,0 @@
import {
getGetV2ListSessionsQueryKey,
type getV2ListSessionsResponse,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import { getSessionListParams } from "./helpers";
const TITLE_POLL_INTERVAL_MS = 2_000;
const TITLE_POLL_MAX_ATTEMPTS = 5;
interface Props {
isReconnecting: boolean;
sessionId: string | null;
status: string;
}
export function useTitlePolling({ isReconnecting, sessionId, status }: Props) {
const queryClient = useQueryClient();
const previousStatusRef = useRef(status);
useEffect(() => {
const previousStatus = previousStatusRef.current;
previousStatusRef.current = status;
const wasActive =
previousStatus === "streaming" || previousStatus === "submitted";
const isNowReady = status === "ready";
if (!wasActive || !isNowReady || !sessionId || isReconnecting) {
return;
}
const params = getSessionListParams();
const queryKey = getGetV2ListSessionsQueryKey(params);
let attempts = 0;
let timeoutId: ReturnType<typeof setTimeout> | undefined;
let isCancelled = false;
const poll = () => {
if (isCancelled) {
return;
}
const data =
queryClient.getQueryData<getV2ListSessionsResponse>(queryKey);
const hasTitle =
data?.status === 200 &&
data.data.sessions.some(
(session) => session.id === sessionId && session.title,
);
if (hasTitle || attempts >= TITLE_POLL_MAX_ATTEMPTS) {
return;
}
attempts += 1;
queryClient.invalidateQueries({ queryKey });
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
};
queryClient.invalidateQueries({ queryKey });
timeoutId = setTimeout(poll, TITLE_POLL_INTERVAL_MS);
return () => {
isCancelled = true;
if (timeoutId) {
clearTimeout(timeoutId);
}
};
}, [isReconnecting, queryClient, sessionId, status]);
}

View File

@@ -1030,16 +1030,6 @@
"default": 0,
"title": "Offset"
}
},
{
"name": "with_auto",
"in": "query",
"required": false,
"schema": {
"type": "boolean",
"default": false,
"title": "With Auto"
}
}
],
"responses": {
@@ -1089,47 +1079,6 @@
}
}
},
"/api/chat/sessions/callback-token/consume": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Consume Callback Token Route",
"operationId": "postV2ConsumeCallbackTokenRoute",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ConsumeCallbackTokenRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ConsumeCallbackTokenResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/sessions/{session_id}": {
"delete": {
"tags": ["v2", "chat", "chat"],
@@ -6721,145 +6670,6 @@
}
}
},
"/api/users/admin/copilot/send-emails": {
"post": {
"tags": ["v2", "admin", "users", "admin"],
"summary": "Send Pending Copilot Emails",
"operationId": "postV2SendPendingCopilotEmails",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SendCopilotEmailsRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SendCopilotEmailsResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/users/admin/copilot/trigger": {
"post": {
"tags": ["v2", "admin", "users", "admin"],
"summary": "Trigger Copilot Session",
"operationId": "postV2TriggerCopilotSession",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TriggerCopilotSessionRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/TriggerCopilotSessionResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/users/admin/copilot/users": {
"get": {
"tags": ["v2", "admin", "users", "admin"],
"summary": "Search Copilot Users",
"operationId": "getV2SearchCopilotUsers",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "search",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Search by email, name, or user ID",
"default": "",
"title": "Search"
},
"description": "Search by email, name, or user ID"
},
{
"name": "limit",
"in": "query",
"required": false,
"schema": {
"type": "integer",
"maximum": 50,
"minimum": 1,
"default": 20,
"title": "Limit"
}
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AdminCopilotUsersResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/users/admin/invited-users": {
"get": {
"tags": ["v2", "admin", "users", "admin"],
@@ -7460,42 +7270,6 @@
"required": ["new_balance", "transaction_key"],
"title": "AddUserCreditsResponse"
},
"AdminCopilotUserSummary": {
"properties": {
"id": { "type": "string", "title": "Id" },
"email": { "type": "string", "title": "Email" },
"name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Name"
},
"timezone": { "type": "string", "title": "Timezone" },
"created_at": {
"type": "string",
"format": "date-time",
"title": "Created At"
},
"updated_at": {
"type": "string",
"format": "date-time",
"title": "Updated At"
}
},
"type": "object",
"required": ["id", "email", "timezone", "created_at", "updated_at"],
"title": "AdminCopilotUserSummary"
},
"AdminCopilotUsersResponse": {
"properties": {
"users": {
"items": { "$ref": "#/components/schemas/AdminCopilotUserSummary" },
"type": "array",
"title": "Users"
}
},
"type": "object",
"required": ["users"],
"title": "AdminCopilotUsersResponse"
},
"AgentDetails": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -7509,12 +7283,14 @@
"inputs": {
"additionalProperties": true,
"type": "object",
"title": "Inputs"
"title": "Inputs",
"default": {}
},
"credentials": {
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
"type": "array",
"title": "Credentials"
"title": "Credentials",
"default": []
},
"execution_options": {
"$ref": "#/components/schemas/ExecutionOptions"
@@ -8091,17 +7867,20 @@
"inputs": {
"additionalProperties": true,
"type": "object",
"title": "Inputs"
"title": "Inputs",
"default": {}
},
"outputs": {
"additionalProperties": true,
"type": "object",
"title": "Outputs"
"title": "Outputs",
"default": {}
},
"credentials": {
"items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
"type": "array",
"title": "Credentials"
"title": "Credentials",
"default": []
}
},
"type": "object",
@@ -8640,16 +8419,6 @@
"required": ["query", "conversation_history", "message_id"],
"title": "ChatRequest"
},
"ChatSessionStartType": {
"type": "string",
"enum": [
"MANUAL",
"AUTOPILOT_NIGHTLY",
"AUTOPILOT_CALLBACK",
"AUTOPILOT_INVITE_CTA"
],
"title": "ChatSessionStartType"
},
"ClarificationNeededResponse": {
"properties": {
"type": {
@@ -8686,20 +8455,6 @@
"title": "ClarifyingQuestion",
"description": "A question that needs user clarification."
},
"ConsumeCallbackTokenRequest": {
"properties": { "token": { "type": "string", "title": "Token" } },
"type": "object",
"required": ["token"],
"title": "ConsumeCallbackTokenRequest"
},
"ConsumeCallbackTokenResponse": {
"properties": {
"session_id": { "type": "string", "title": "Session Id" }
},
"type": "object",
"required": ["session_id"],
"title": "ConsumeCallbackTokenResponse"
},
"ContentType": {
"type": "string",
"enum": [
@@ -11123,7 +10878,8 @@
"suggestions": {
"items": { "type": "string" },
"type": "array",
"title": "Suggestions"
"title": "Suggestions",
"default": []
},
"name": { "type": "string", "title": "Name", "default": "no_results" }
},
@@ -12159,7 +11915,6 @@
"error",
"no_results",
"need_login",
"completion_report_saved",
"agents_found",
"agent_details",
"setup_requirements",
@@ -12416,37 +12171,6 @@
"required": ["items", "search_id", "total_items", "pagination"],
"title": "SearchResponse"
},
"SendCopilotEmailsRequest": {
"properties": { "user_id": { "type": "string", "title": "User Id" } },
"type": "object",
"required": ["user_id"],
"title": "SendCopilotEmailsRequest"
},
"SendCopilotEmailsResponse": {
"properties": {
"candidate_count": { "type": "integer", "title": "Candidate Count" },
"processed_count": { "type": "integer", "title": "Processed Count" },
"sent_count": { "type": "integer", "title": "Sent Count" },
"skipped_count": { "type": "integer", "title": "Skipped Count" },
"repair_queued_count": {
"type": "integer",
"title": "Repair Queued Count"
},
"running_count": { "type": "integer", "title": "Running Count" },
"failed_count": { "type": "integer", "title": "Failed Count" }
},
"type": "object",
"required": [
"candidate_count",
"processed_count",
"sent_count",
"skipped_count",
"repair_queued_count",
"running_count",
"failed_count"
],
"title": "SendCopilotEmailsResponse"
},
"SessionDetailResponse": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -12456,11 +12180,6 @@
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Id"
},
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
"execution_tag": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Execution Tag"
},
"messages": {
"items": { "additionalProperties": true, "type": "object" },
"type": "array",
@@ -12474,14 +12193,7 @@
}
},
"type": "object",
"required": [
"id",
"created_at",
"updated_at",
"user_id",
"start_type",
"messages"
],
"required": ["id", "created_at", "updated_at", "user_id", "messages"],
"title": "SessionDetailResponse",
"description": "Response model providing complete details for a chat session, including messages."
},
@@ -12494,21 +12206,10 @@
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Title"
},
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" },
"execution_tag": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Execution Tag"
},
"is_processing": { "type": "boolean", "title": "Is Processing" }
},
"type": "object",
"required": [
"id",
"created_at",
"updated_at",
"start_type",
"is_processing"
],
"required": ["id", "created_at", "updated_at", "is_processing"],
"title": "SessionSummaryResponse",
"description": "Response model for a session summary (without messages)."
},
@@ -14139,24 +13840,6 @@
"required": ["transactions", "next_transaction_time"],
"title": "TransactionHistory"
},
"TriggerCopilotSessionRequest": {
"properties": {
"user_id": { "type": "string", "title": "User Id" },
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
},
"type": "object",
"required": ["user_id", "start_type"],
"title": "TriggerCopilotSessionRequest"
},
"TriggerCopilotSessionResponse": {
"properties": {
"session_id": { "type": "string", "title": "Session Id" },
"start_type": { "$ref": "#/components/schemas/ChatSessionStartType" }
},
"type": "object",
"required": ["session_id", "start_type"],
"title": "TriggerCopilotSessionResponse"
},
"TriggeredPresetSetupRequest": {
"properties": {
"name": { "type": "string", "title": "Name" },
@@ -15088,7 +14771,8 @@
"missing_credentials": {
"additionalProperties": true,
"type": "object",
"title": "Missing Credentials"
"title": "Missing Credentials",
"default": {}
},
"ready_to_run": {
"type": "boolean",

View File

@@ -0,0 +1,343 @@
# Workspace & Media File Architecture
This document describes the architecture for handling user files in AutoGPT Platform, covering persistent user storage (Workspace) and ephemeral media processing pipelines.
## Overview
The platform has two distinct file-handling layers:
| Layer | Purpose | Persistence | Scope |
|-------|---------|-------------|-------|
| **Workspace** | Long-term user file storage | Persistent (DB + GCS/local) | Per-user, session-scoped access |
| **Media Pipeline** | Ephemeral file processing for blocks | Temporary (local disk) | Per-execution |
## Database Models
### UserWorkspace
Represents a user's file storage space. Created on-demand (one per user).
```prisma
model UserWorkspace {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String @unique
Files UserWorkspaceFile[]
}
```
**Key points:**
- One workspace per user (enforced by `@unique` on `userId`)
- Created lazily via `get_or_create_workspace()`
- Uses upsert to handle race conditions
### UserWorkspaceFile
Represents a file stored in a user's workspace.
```prisma
model UserWorkspaceFile {
id String @id @default(uuid())
workspaceId String
name String // User-visible filename
path String // Virtual path (e.g., "/sessions/abc123/image.png")
storagePath String // Actual storage path (gcs://... or local://...)
mimeType String
sizeBytes BigInt
checksum String? // SHA256 for integrity
isDeleted Boolean @default(false)
deletedAt DateTime?
metadata Json @default("{}")
@@unique([workspaceId, path]) // Enforce unique paths within workspace
}
```
**Key points:**
- `path` is a virtual path for organizing files (not actual filesystem path)
- `storagePath` contains the actual GCS or local storage location
- Soft-delete pattern: `isDeleted` flag with `deletedAt` timestamp
- Path is modified on delete to free up the virtual path for reuse
---
## WorkspaceManager
**Location:** `backend/util/workspace.py`
High-level API for workspace file operations. Combines storage backend operations with database record management.
### Initialization
```python
from backend.util.workspace import WorkspaceManager
# Basic usage
manager = WorkspaceManager(user_id="user-123", workspace_id="ws-456")
# With session scoping (CoPilot sessions)
manager = WorkspaceManager(
user_id="user-123",
workspace_id="ws-456",
session_id="session-789"
)
```
### Session Scoping
When `session_id` is provided, files are isolated to `/sessions/{session_id}/`:
```python
# With session_id="abc123":
manager.write_file(content, "image.png")
# → stored at /sessions/abc123/image.png
# Cross-session access is explicit:
manager.read_file("/sessions/other-session/file.txt") # Works
```
**Why session scoping?**
- CoPilot conversations need file isolation
- Prevents file collisions between concurrent sessions
- Allows session cleanup without affecting other sessions
### Core Methods
| Method | Description |
|--------|-------------|
| `write_file(content, filename, path?, mime_type?, overwrite?)` | Write file to workspace |
| `read_file(path)` | Read file by virtual path |
| `read_file_by_id(file_id)` | Read file by ID |
| `list_files(path?, limit?, offset?, include_all_sessions?)` | List files |
| `delete_file(file_id)` | Soft-delete a file |
| `get_download_url(file_id, expires_in?)` | Get signed download URL |
| `get_file_info(file_id)` | Get file metadata |
| `get_file_info_by_path(path)` | Get file metadata by path |
| `get_file_count(path?, include_all_sessions?)` | Count files |
### Storage Backends
WorkspaceManager delegates to `WorkspaceStorageBackend`:
| Backend | When Used | Storage Path Format |
|---------|-----------|---------------------|
| `GCSWorkspaceStorage` | `media_gcs_bucket_name` is configured | `gcs://bucket/workspaces/{ws_id}/{file_id}/{filename}` |
| `LocalWorkspaceStorage` | No GCS bucket configured | `local://{ws_id}/{file_id}/{filename}` |
---
## store_media_file()
**Location:** `backend/util/file.py`
The media normalization pipeline. Handles various input types and normalizes them for processing or output.
### Purpose
Blocks receive files in many formats (URLs, data URIs, workspace references, local paths). `store_media_file()` normalizes these to a consistent format based on what the block needs.
### Input Types Handled
| Input Format | Example | How It's Processed |
|--------------|---------|-------------------|
| Data URI | `data:image/png;base64,iVBOR...` | Decoded, virus scanned, written locally |
| HTTP(S) URL | `https://example.com/image.png` | Downloaded, virus scanned, written locally |
| Workspace URI | `workspace://abc123` or `workspace:///path/to/file` | Read from workspace, virus scanned, written locally |
| Cloud path | `gcs://bucket/path` | Downloaded, virus scanned, written locally |
| Local path | `image.png` | Verified to exist in exec_file directory |
### Return Formats
The `return_format` parameter determines what you get back:
```python
from backend.util.file import store_media_file
# For local processing (ffmpeg, MoviePy, PIL)
local_path = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_local_processing"
)
# Returns: "image.png" (relative path in exec_file dir)
# For external APIs (Replicate, OpenAI, etc.)
data_uri = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_external_api"
)
# Returns: "data:image/png;base64,iVBOR..."
# For block output (adapts to execution context)
output = await store_media_file(
file=input_file,
execution_context=ctx,
return_format="for_block_output"
)
# In CoPilot: Returns "workspace://file-id#image/png"
# In graphs: Returns "data:image/png;base64,..."
```
### Execution Context
`store_media_file()` requires an `ExecutionContext` with:
- `graph_exec_id` - Required for temp file location
- `user_id` - Required for workspace access
- `workspace_id` - Optional; enables workspace features
- `session_id` - Optional; for session scoping in CoPilot
---
## Responsibility Boundaries
### Virus Scanning
| Component | Scans? | Notes |
|-----------|--------|-------|
| `store_media_file()` | ✅ Yes | Scans **all** content before writing to local disk |
| `WorkspaceManager.write_file()` | ✅ Yes | Scans content before persisting |
**Scanning happens at:**
1. `store_media_file()` — scans everything it downloads/decodes
2. `WorkspaceManager.write_file()` — scans before persistence
Tools like `WriteWorkspaceFileTool` don't need to scan because `WorkspaceManager.write_file()` handles it.
### Persistence
| Component | Persists To | Lifecycle |
|-----------|-------------|-----------|
| `store_media_file()` | Temp dir (`/tmp/exec_file/{exec_id}/`) | Cleaned after execution |
| `WorkspaceManager` | GCS or local storage + DB | Persistent until deleted |
**Automatic cleanup:** `clean_exec_files(graph_exec_id)` removes temp files after execution completes.
---
## Decision Tree: WorkspaceManager vs store_media_file
```text
┌─────────────────────────────────────────────────────┐
│ What do you need to do with the file? │
└─────────────────────────────────────────────────────┘
┌─────────────┴─────────────┐
▼ ▼
Process in a block Store for user access
(ffmpeg, PIL, etc.) (CoPilot files, uploads)
│ │
▼ ▼
store_media_file() WorkspaceManager
with appropriate
return_format
┌──────┴──────┐
▼ ▼
"for_local_ "for_block_
processing" output"
│ │
▼ ▼
Get local Auto-saves to
path for workspace in
tools CoPilot context
Store for user access
├── write_file() ─── Upload + persist (scans internally)
├── read_file() / get_download_url() ─── Retrieve
└── list_files() / delete_file() ─── Manage
```
### Quick Reference
| Scenario | Use |
|----------|-----|
| Block needs to process a file with ffmpeg | `store_media_file(..., return_format="for_local_processing")` |
| Block needs to send file to external API | `store_media_file(..., return_format="for_external_api")` |
| Block returning a generated file | `store_media_file(..., return_format="for_block_output")` |
| API endpoint handling file upload | `WorkspaceManager.write_file()` (handles virus scanning internally) |
| API endpoint serving file download | `WorkspaceManager.get_download_url()` |
| Listing user's files | `WorkspaceManager.list_files()` |
---
## Key Files Reference
| File | Purpose |
|------|---------|
| `backend/data/workspace.py` | Database CRUD operations for UserWorkspace and UserWorkspaceFile |
| `backend/util/workspace.py` | `WorkspaceManager` class - high-level workspace API |
| `backend/util/workspace_storage.py` | Storage backends (GCS, local) and `WorkspaceStorageBackend` interface |
| `backend/util/file.py` | `store_media_file()` and media processing utilities |
| `backend/util/virus_scanner.py` | `VirusScannerService` and `scan_content_safe()` |
| `schema.prisma` | Database model definitions |
---
## Common Patterns
### Block Processing a User's File
```python
async def run(self, input_data, *, execution_context, **kwargs):
# Normalize input to local path
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# Process with local tools
output_path = process_video(local_path)
# Return (auto-saves to workspace in CoPilot)
result = await store_media_file(
file=output_path,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output", result
```
### API Upload Endpoint
```python
from backend.util.virus_scanner import VirusDetectedError, VirusScanError
async def upload_file(file: UploadFile, user_id: str, workspace_id: str):
content = await file.read()
# write_file handles virus scanning internally
manager = WorkspaceManager(user_id, workspace_id)
try:
workspace_file = await manager.write_file(
content=content,
filename=file.filename,
)
except VirusDetectedError:
raise HTTPException(status_code=400, detail="File rejected: virus detected")
except VirusScanError:
raise HTTPException(status_code=503, detail="Virus scanning unavailable")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"file_id": workspace_file.id}
```
---
## Configuration
| Setting | Purpose | Default |
|---------|---------|---------|
| `media_gcs_bucket_name` | GCS bucket for workspace storage | None (uses local) |
| `workspace_storage_dir` | Local storage directory | `{app_data}/workspaces` |
| `max_file_size_mb` | Maximum file size in MB | 100 |
| `clamav_service_enabled` | Enable virus scanning | true |
| `clamav_service_host` | ClamAV daemon host | localhost |
| `clamav_service_port` | ClamAV daemon port | 3310 |
| `clamav_max_concurrency` | Max concurrent scans to ClamAV daemon | 5 |
| `clamav_mark_failed_scans_as_clean` | If true, scan failures pass content through instead of rejecting (⚠️ security risk if ClamAV is unreachable) | false |