Compare commits

...

38 Commits

Author SHA1 Message Date
majdyz
42ccef316a fix(backend/copilot): convert absolute copilot imports to relative in sdk/service.py
Replaces all `from backend.copilot.X import Y` with `from ..X import Y`
(and inline function-body imports) to fix Pyright type collisions from
mixed absolute/relative imports. Adds `# isort: skip_file` to prevent
regression.
2026-04-15 20:27:10 +07:00
majdyz
26b5f9958b fix(frontend): add .catch to clipboard writeText in LogsTable exec ID cell
Unhandled promise rejections can occur if clipboard API fails (permissions
or browser restrictions). Silently swallow the error — admin table copy
is best-effort and the truncated ID is still visible in the cell.
2026-04-15 20:03:26 +07:00
majdyz
a281e38620 test(frontend): cover LogsTable clipboard copy for exec ID cell 2026-04-15 19:59:26 +07:00
majdyz
66cda847c7 test(frontend/copilot): add CopilotPage banner coverage for sessionDryRun
Adds integration tests that verify the test-mode banner renders only
when both sessionId and sessionDryRun are truthy, covering the new
conditional introduced in the PR to prevent stale-preference banners.
2026-04-15 19:52:44 +07:00
majdyz
f72bddb51e feat(frontend): copy full execution ID to clipboard on cell click
Clicking the truncated execution ID cell on the platform costs page now
copies the full graph_exec_id to the clipboard. Display remains as the
first 8 characters. A cursor-pointer class and title tooltip (showing
the full ID on hover) signal the cell is clickable.
2026-04-15 19:49:58 +07:00
majdyz
a96517ccb5 test(frontend): cover graph_exec_id filter apply path in PlatformCostContent 2026-04-15 19:45:36 +07:00
majdyz
9493c55108 Merge remote-tracking branch 'origin/dev' into fix/copilot-model-toggle-styling 2026-04-15 19:35:35 +07:00
Zamil Majdy
d23ca824ad fix(copilot): set session_id on mode-switch T1 to enable --resume on subsequent SDK turns (#12795)
## Why

When a user switches from **baseline** (fast) mode to **SDK**
(extended_thinking) mode mid-session, every subsequent SDK turn started
fresh with no memory of prior conversation.

Root cause: two complementary bugs on mode-switch T1 (first SDK turn
after baseline turns):
1. `session_id` was gated on `not has_history`. On mode-switch T1,
`has_history=True` (prior baseline turns in DB) so no `session_id` was
set. The CLI generated a random ID and could not upload the session file
under a predictable path → `--resume` failed on every following SDK
turn.
2. Even if `session_id` were set, the upload guard `(not has_history or
state.use_resume)` would block the session file upload on mode-switch T1
(`has_history=True`, `use_resume=False`), so the next turn still cannot
`--resume`.

Together these caused every SDK turn to re-inject the full compressed
history, causing model confusion (proactive tool calls, forgetting
context) observed in session `8237a27b-45d0-4688-af20-c185379e926f`.

## What

- **`service.py`**: Change `elif not has_history:` → `else:` for the
`session_id` assignment — set it whenever `--resume` is not active.
Covers T1 fresh, mode-switch T1 (`has_history=True` but no CLI session
exists), and T2+ fallback turns where restore failed.
- **`service.py` retry path**: Replace `not has_history` with
`"session_id" in sdk_options_kwargs` as the discriminator, so
mode-switch T1 retries also keep `session_id` while T2+ retries (where
`restore_cli_session` put a file on disk) correctly remove it to avoid
"Session ID already in use".
- **`service.py` upload guard**: Remove `and not skip_transcript_upload`
and `and (not has_history or state.use_resume)` from the
`upload_cli_session` guard. The CLI session file is independent of the
JSONL transcript; and upload must run on mode-switch T1 so the next turn
can `--resume`. `upload_cli_session` silently skips when the file is
absent, so unconditional upload is always safe.
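For orientation, a minimal sketch of the gate change in Python, assuming the names used in this description (`state.use_resume`, `sdk_options_kwargs`); the actual `service.py` logic is more involved:

```python
# Hypothetical sketch, not the actual service.py code.
def build_session_kwargs(state, session_id: str) -> dict:
    sdk_options_kwargs: dict = {}
    if state.use_resume:
        # A CLI session was restored for this turn: resume it.
        sdk_options_kwargs["resume"] = session_id
    else:
        # Previously `elif not has_history:`, which skipped mode-switch T1
        # (has_history=True but no CLI session exists). session_id is now
        # set whenever --resume is not active.
        sdk_options_kwargs["session_id"] = session_id
    return sdk_options_kwargs
```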

## How

| Scenario | Before | After |
|---|---|---|
| T1 fresh (`has_history=False`) | `session_id` set ✓ | `session_id` set ✓ |
| Mode-switch T1 (`has_history=True`, no CLI session) | not set — **bug** | `session_id` set ✓ |
| T2+ with `--resume` | `resume` set ✓ | `resume` set ✓ |
| T2+ retry after `--resume` failed | `session_id` removed ✓ | `session_id` removed ✓ |
| Mode-switch T1 retry | `session_id` removed | `session_id` kept ✓ |
| Upload on mode-switch T1 | blocked by guard — **bug** | uploaded ✓ |

7 new unit tests in `TestSdkSessionIdSelection` document all session_id
cases.
6 new tests in `mode_switch_context_test.py` cover transcript bridging
for both fast→SDK and SDK→fast switches.

## Checklist

- [x] I have read the contributing guidelines
- [x] My changes are covered by tests
- [x] `poetry run format` passes

---------

Co-authored-by: Zamil Majdy <zamilmajdy@gmail.com>
2026-04-15 19:03:18 +07:00
majdyz
4ab5c64c09 chore: resolve merge conflicts with dev 2026-04-15 19:02:35 +07:00
Zamil Majdy
227c60abd3 fix(backend/copilot): idempotency guard + frontend dedup fix for duplicate messages (#12788)
## Why

After merging #12782 to dev, a k8s rolling deployment triggered
infrastructure-level POST retries — nginx detected the old pod's
connection reset mid-stream and resent the same POST to a new pod. Both
pods independently saved the user message and ran the executor,
producing duplicate entries in the DB (seq 159, 161, 163) and a
duplicate response in the chat. The model saw the same question 3× in
its context window and spent its response commenting on that instead of
answering.

Two compounding issues:
1. **No backend idempotency**: `append_and_save_message` saves
unconditionally — k8s/nginx retries silently produce duplicate turns.
2. **Frontend dedup cleared after success**:
`lastSubmittedMsgRef.current = null` after every completed turn wipes
the dedup guard, so any rapid re-submit of the same text (from a stalled
UI or user double-click) slips through.

## What

**Backend** — Redis idempotency gate in `stream_chat_post`:
- Before saving the user message, compute `sha256(session_id +
message)[:16]` and `SET NX ex=30` in Redis
- If key already exists → duplicate: return empty SSE (`StreamFinish +
[DONE]`) immediately, skip save + executor enqueue
- User messages only (`is_user_message=True`); system/assistant messages
bypass the check
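A rough sketch of that gate, assuming a standard redis-py asyncio client (the key prefix and helper name are illustrative, not the actual backend API):

```python
import hashlib

DEDUP_TTL_SECONDS = 30  # covers the k8s/nginx retry window

async def is_first_post(redis, session_id: str, message: str) -> bool:
    """True for a fresh (session, message) pair, False for a duplicate."""
    digest = hashlib.sha256((session_id + message).encode()).hexdigest()[:16]
    # SET NX succeeds only when the key does not already exist.
    was_set = await redis.set(
        f"chat:dedup:{digest}", "1", nx=True, ex=DEDUP_TTL_SECONDS
    )
    return bool(was_set)
```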

**Frontend** — Keep `lastSubmittedMsgRef` populated after success:
- Remove `lastSubmittedMsgRef.current = null` on stream complete
- `getSendSuppressionReason` already has a two-condition check: `ref ===
text AND lastUserMsg === text` — so legitimate re-asks (after a
different question was answered) still work; only rapid re-sends of the
exact same text while it's still the last user message are blocked

## How

- 30 s Redis TTL covers infrastructure retry windows (k8s SIGTERM →
connection reset → ingress retry typically < 5 s)
- Empty SSE response is well-formed (StreamFinish + [DONE]) — frontend
AI SDK marks the turn complete without rendering a ghost message
- Frontend ref kept live means: submit "foo" → success → submit "foo"
again instantly → suppressed. Submit "foo" → success → submit "bar" →
proceeds (different text updates the ref).

## Tests

- 3 new backend route tests: duplicate blocked, first POST proceeds,
non-user messages bypass
- 5 new frontend `getSendSuppressionReason` unit tests: fresh ref,
reconnecting, duplicate suppressed, different-turn re-ask allowed,
different text allowed

## Checklist

- [x] I have read the [AutoGPT Contributing
Guide](https://github.com/Significant-Gravitas/AutoGPT/blob/master/CONTRIBUTING.md)
- [x] I have performed a self-review of my code
- [x] I have added tests that prove the fix is effective
- [x] I have run `poetry run format` and `pnpm format` + `pnpm lint`
2026-04-15 18:54:59 +07:00
Zamil Majdy
d0592d63a6 test(frontend): extract resolveSessionDryRun to helper + add coverage
- Extract session dry_run resolution logic from useChatSession.ts into
  resolveSessionDryRun() helper in helpers.ts for testability
- Add 6 unit tests covering null/undefined/non-200/false/missing/true
  branches of resolveSessionDryRun in __tests__/helpers.test.ts
- Remove unused isDryRun destructure from CopilotPage.tsx (now only
  sessionDryRun is used for the session-scoped test mode banner)
- Fix patch coverage gap: new sessionDryRun logic is now fully covered
2026-04-15 18:51:43 +07:00
Zamil Majdy
fa064aa4f1 fix(copilot): only show test-mode banner when session metadata confirms dry_run=true 2026-04-15 18:43:37 +07:00
Zamil Majdy
6bdfadf903 fix(frontend): rename executionIdFilter/Input to executionIDFilter/Input
Fully-capitalize the ID acronym in symbol names per repo convention
(graphID, useBackendAPI pattern). No functional change.
2026-04-15 18:42:57 +07:00
majdyz
737aa20f80 fix(copilot): show mode/model toggles in existing sessions when not streaming 2026-04-15 18:42:11 +07:00
majdyz
5f79164c53 fix(copilot): derive test-mode banner from session metadata, hide dry-run toggle on active sessions 2026-04-15 18:40:04 +07:00
majdyz
ff65f58ba9 fix(copilot): match ModelToggleButton inactive state to DryRunToggleButton pattern (transparent, hover-only) 2026-04-15 18:06:31 +07:00
majdyz
cc89f245ce test(frontend): add Clear button coverage for execution ID filter reset 2026-04-15 18:05:22 +07:00
majdyz
8202f48e46 test(backend/frontend): add coverage for graph_exec_id filter and session-aware toggle visibility
Backend: add tests for graph_exec_id param in _build_prisma_where,
_build_raw_where, get_platform_cost_dashboard, get_platform_cost_logs,
and get_platform_cost_logs_for_export.

Frontend: add tests for Execution ID filter input in PlatformCostContent
and hasSession-based hiding of mode/model toggles in ChatInput.
2026-04-15 18:03:36 +07:00
majdyz
43e8159822 fix(copilot): simplify toggle visibility — hide mode/model on session, always show dry-run; icon-only for inactive model state 2026-04-15 17:59:53 +07:00
majdyz
25e34829bc Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/copilot-model-toggle-styling 2026-04-15 17:58:50 +07:00
majdyz
6a091a17d2 fix(frontend): update openapi.json with graph_exec_id query param
Backend added graph_exec_id filter to platform cost routes; sync
the exported openapi.json so CI schema check passes.
2026-04-15 17:56:23 +07:00
Ubbe
0284614df0 fix(copilot): abort SSE stream and disconnect backend listeners on session switch (#12766)
## Summary

Fixes stream disconnection bugs where the UI shows "running" with no
output when users switch between copilot chat sessions. The root cause
is that the old SSE fetch is not aborted and backend XREAD listeners
keep running until timeout when switching sessions.

### Changes

**Frontend (`useCopilotStream.ts`, `helpers.ts`)**
- Call `sdkStop()` on session switch to abort the in-flight SSE fetch
from the old session's transport
- Fire-and-forget `DELETE` to new backend disconnect endpoint so
server-side listeners release immediately
- Store `resumeStream` and `sdkStop` in refs to fix stale closure bugs in:
  - Wake re-sync visibility handler (could call stale `resumeStream` after tab sleep)
  - Reconnect timer callback (could target wrong session's transport)
  - Resume effect (captured stale `resumeStream` during rapid session switches)

**Backend (`stream_registry.py`, `routes.py`)**
- Add `disconnect_all_listeners(session_id)` to stream registry —
iterates active listener tasks, cancels any matching the session
- Add `DELETE /sessions/{session_id}/stream` endpoint — auth-protected,
calls `disconnect_all_listeners`, returns 204
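A sketch of what the registry method plausibly does; the task-to-session mapping is an assumption about `stream_registry.py` internals:

```python
import asyncio

# Assumed registry shape: one entry per live XREAD listener task.
_active_listeners: dict[asyncio.Task, str] = {}

async def disconnect_all_listeners(session_id: str) -> int:
    """Cancel every active listener task attached to the given session."""
    cancelled = 0
    for task, sid in list(_active_listeners.items()):
        if sid == session_id and not task.done():
            task.cancel()  # listener exits now rather than waiting out the timeout
            cancelled += 1
    return cancelled
```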

### Why

Reported by multiple team members: when using Autopilot for anything
serious, the frontend loses the SSE connection — particularly when
switching between conversations. The backend completes fine (refreshing
shows full output), but the UI gets stuck showing "running". This is the
worst UX bug we have right now because real users will never know to
refresh.

### How to test

1. Start a long-running autopilot task (e.g., "build a snake game")
2. While it's streaming, switch to a different chat session
3. Switch back — the UI should correctly show the completed output or
resume the stream
4. Verify no "stuck running" state

## Test plan

- [ ] Manual: switch sessions during active stream — no stuck "running"
state
- [ ] Manual: background tab for >30s during stream, return — wake
re-sync works
- [ ] Manual: trigger reconnect (kill network briefly) — reconnects to
correct session
- [ ] Verify: `pnpm lint`, `pnpm types`, `poetry run lint` all pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: majdyz <zamil.majdy@agpt.co>
2026-04-15 09:50:19 +00:00
majdyz
5cfb6ffdaa fix(copilot): use read-only aria-label on disabled mode/model toggles
When readOnly=true, aria-label now describes the current state
("Advanced model active for this session") instead of announcing
an unavailable switch action, per WCAG accessible name guidance.
2026-04-15 16:36:49 +07:00
majdyz
f49a9f728c fix(copilot): address review - raw_where graph_exec_id, disabled hover styles, readOnly tests
- Add graph_exec_id to _build_raw_where so percentile/bucket SQL
  queries respect execution ID filter (was only Prisma queries)
- Add disabled:hover:bg-* to ModelToggleButton and ModeToggleButton
  so hover styles don't fire when button is disabled/readOnly
- Add readOnly test cases to ModelToggleButton.test.tsx covering
  disabled attribute, no-op click, and session-locked tooltip
2026-04-15 16:35:40 +07:00
majdyz
53925d2e2b fix(frontend): hide mode/model toggles during session, show as read-only when non-default 2026-04-15 16:15:22 +07:00
majdyz
aa2d2d7371 feat(frontend): add execution ID filter to platform cost admin page 2026-04-15 16:00:25 +07:00
majdyz
661fffe133 fix(copilot): align ModelToggleButton styling with ModeToggleButton pattern
Standard state now shows a background (neutral-100) and 'Standard' label,
matching the ModeToggleButton where both states are always visually distinct
with a colored background and text label.
2026-04-15 15:47:22 +07:00
Zamil Majdy
f835674498 feat(copilot): standard/advanced model toggle with Opus rate-limit multiplier (#12786)
## Why

Users have different task complexity needs. Sonnet is fast and cheap for
most queries; Opus is more capable for hard reasoning tasks. Exposing
this as a simple toggle gives users control without requiring
infrastructure complexity.

Opus costs 5× more than Sonnet per Anthropic pricing ($15/$75 vs $3/$15
per M tokens). Rather than adding a separate entitlement gate, the
rate-limit multiplier (5×) ensures Opus turns deplete the daily/weekly
quota proportionally faster — users self-limit via their existing
budget.

## What

- **Standard/Advanced model toggle** in the chat input toolbar (sky-blue
star icon, label only when active — matches the simulation
DryRunToggleButton pattern but visually distinct)
- **`CopilotLlmModel = Literal["standard", "advanced"]`** —
model-agnostic tier names (not tied to Anthropic model names)
- **Backend model resolution**: `"advanced"` → `claude-opus-4-6`,
`"standard"` → `config.model` (currently Sonnet)
- **Rate-limit multiplier**: Opus turns count as 5× in Redis token
counters (daily + weekly limits). Does **not** affect `PlatformCostLog`
or `cost_usd` — those use real API-reported values
- **localStorage persistence** via `Key.COPILOT_MODEL` so preference
survives page refresh
- **`claude_agent_max_budget_usd`** reduced from $15 to $10

## How

### Backend
- `CopilotLlmModel` type added to `config.py`, imported in
routes/executor/service
- `stream_chat_completion_sdk` accepts `model: CopilotLlmModel | None`
- Model tier resolved early in the SDK path; `_normalize_model_name`
strips the OpenRouter provider prefix
- `model_cost_multiplier` (1.0 or 5.0) computed from final resolved
model name, passed to `persist_and_record_usage` → `record_token_usage`
(Redis only)
- No separate LD flag needed — rate limit is the gate
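Sketched with illustrative helper names (this description pins down only the tier mapping and the 5x factor):

```python
from typing import Literal

CopilotLlmModel = Literal["standard", "advanced"]

def resolve_model(tier: CopilotLlmModel | None, config_model: str) -> str:
    # "advanced" maps to Opus; "standard" or None falls back to config.model.
    return "claude-opus-4-6" if tier == "advanced" else config_model

def model_cost_multiplier(resolved_model: str) -> float:
    # Applied only to the Redis daily/weekly token counters, never to
    # PlatformCostLog or cost_usd (those use real API-reported values).
    return 5.0 if "opus" in resolved_model else 1.0
```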

### Frontend
- `ModelToggleButton` component: sky-blue, star icon, "Advanced" label
when active
- `copilotModel` state in `useCopilotUIStore` with localStorage
hydration
- `copilotModelRef` pattern in `useCopilotStream` (avoids recreating
`DefaultChatTransport`)
- Toggle gated behind `showModeToggle && !isStreaming` in `ChatInput`

## Checklist
- [x] Tests added/updated (ModelToggleButton.test.tsx,
service_helpers_test.py, token_tracking_test.py)
- [x] Rate-limit multiplier only affects Redis counters, not cost
tracking
- [x] No new LD flag needed
2026-04-15 15:37:11 +07:00
Zamil Majdy
da18f372f7 feat(backend/copilot): add for_agent_generation flag to find_block (#12787)
## Why
When the agent generator LLM builds a graph, it may need to look up
schema details for graph-only blocks like `AgentInputBlock`,
`AgentOutputBlock`, or `OrchestratorBlock`. These blocks are correctly
hidden from regular CoPilot `find_block` results (they can't run
standalone), but that same filter was also preventing the LLM from
discovering them when composing an agent graph.

## What
Added a `for_agent_generation: bool = False` parameter to
`FindBlockTool`.

## How
- `for_agent_generation=false` (default): existing behaviour unchanged —
graph-only blocks are filtered from both UUID lookups and text search
results.
- `for_agent_generation=true`: bypasses `COPILOT_EXCLUDED_BLOCK_TYPES` /
`COPILOT_EXCLUDED_BLOCK_IDS` so the LLM can find and inspect schemas for
INPUT, OUTPUT, ORCHESTRATOR, WEBHOOK, etc. blocks when building agent
JSON.
- MCP_TOOL blocks are still excluded even with
`for_agent_generation=true` (they go through `run_mcp_tool`, not
`find_block`).
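As a sketch (only the two constant names come from this description; the block fields are illustrative):

```python
def is_excluded(block, for_agent_generation: bool = False) -> bool:
    if block.block_type == "MCP_TOOL":
        return True  # always excluded; MCP tools run via run_mcp_tool
    if for_agent_generation:
        return False  # agent generator may inspect graph-only blocks
    return (
        block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
        or block.id in COPILOT_EXCLUDED_BLOCK_IDS
    )
```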

## Checklist
- [x] No new dependencies
- [x] Backward compatible (default `false` preserves existing behaviour)
- [x] No frontend changes
2026-04-15 14:57:17 +07:00
Zamil Majdy
d82ecac363 fix(backend/copilot): null-safe token accumulation for OpenRouter null cache fields (#12789)
## Why
OpenRouter occasionally returns `null` (not `0`) for
`cache_read_input_tokens` and `cache_creation_input_tokens` on the
initial streaming event, before real token counts are available.
Python's `dict.get(key, 0)` only falls back to `0` when the key is
**missing** — when the key exists with a `null` value, `.get(key, 0)`
returns `None`. This causes `TypeError: unsupported operand type(s) for
+=: 'int' and 'NoneType'` in the usage accumulator on the first
streaming chunk from OpenRouter models.
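The distinction in a snippet:

```python
usage = {"cache_read_input_tokens": None}  # OpenRouter's first streaming event

usage.get("cache_read_input_tokens", 0)    # None: key exists, default unused
usage.get("cache_read_input_tokens") or 0  # 0: null-safe

total = 0
total += usage.get("cache_read_input_tokens") or 0  # no TypeError
```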

## What
- Replace `.get(key, 0)` with `.get(key) or 0` for all four token fields
in `_run_stream_attempt`
- Add `TestTokenUsageNullSafety` unit tests in `service_helpers_test.py`

## How
Minimal targeted fix — only the four `+=` accumulation lines changed. No
behaviour change for Anthropic-native models (they never emit null
values).

## Checklist
- [x] Tests cover null event, real event, absent keys, and multi-turn
accumulation
- [x] No behaviour change for Anthropic-native models
- [x] No API changes
2026-04-15 14:50:34 +07:00
Zamil Majdy
8a2e2365f7 fix(backend/executor): charge per LLM iteration and per tool call in OrchestratorBlock (#12735)
### Why / What / How

**Why:** The OrchestratorBlock in agent mode makes multiple LLM calls in
a single node execution (one per iteration of the tool-calling loop),
but the executor was only charging the user once per run via
`_charge_usage`. Tools spawned by the orchestrator also bypassed
`_charge_usage` entirely — they execute via `on_node_execution()`
directly without going through the main execution queue, producing free
internal block executions.

**What:**
1. Charge `base_cost * (llm_call_count - 1)` extra credits after the
orchestrator block completes — covers the additional iterations beyond
the first (which is already paid for upfront).
2. Charge user credits for tools executed inside the orchestrator, the
same way queue-driven node executions are charged.

**How:**

**1. Per-iteration LLM charging**
- New `Block.extra_runtime_cost(execution_stats)` virtual method
(default returns `0`)
- `OrchestratorBlock` overrides it to return `max(0, llm_call_count -
1)`
- New `resolve_block_cost` free function in `billing.py` centralises the
block-lookup + cost-calculation pattern (used by both `charge_usage` and
`charge_extra_runtime_cost`)
- New `billing.charge_extra_runtime_cost(node_exec, extra_count)`
function that debits `base_cost * min(extra_count,
_MAX_EXTRA_RUNTIME_COST)` via `spend_credits()`, running synchronously
in a thread-pool worker
- After `_on_node_execution` completes with COMPLETED status,
`on_node_execution` calls `charge_extra_runtime_cost` if
`extra_runtime_cost > 0` and not a dry run
- `InsufficientBalanceError` from post-hoc charging is treated as a
billing leak: logged at ERROR with `billing_leak: True` structured
fields, user is notified via `_handle_insufficient_funds_notif`, but the
run status stays COMPLETED (work already done)
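A sketch of the hook pair described above (stats field access simplified):

```python
class Block:
    def extra_runtime_cost(self, execution_stats) -> int:
        return 0  # default: nothing beyond the upfront per-run charge

class OrchestratorBlock(Block):
    def extra_runtime_cost(self, execution_stats) -> int:
        # Iterations beyond the first (already paid upfront) are charged
        # post-hoc; billing.charge_extra_runtime_cost caps the count at
        # _MAX_EXTRA_RUNTIME_COST before debiting via spend_credits().
        return max(0, execution_stats.llm_call_count - 1)
```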

**2. Tool execution charging**
- New public async `ExecutionProcessor.charge_node_usage(node_exec)`
wrapper around `charge_usage` (with `execution_count=0` to avoid
inflating execution-tier counters); also calls `_handle_low_balance`
internally
- `OrchestratorBlock._execute_single_tool_with_manager` calls
`charge_node_usage` after successful tool execution (skipped for dry
runs and failed/cancelled tool runs)
- Tool cost is added to the orchestrator's `extra_cost` so it shows up
in graph stats display
- `InsufficientBalanceError` from tool charging is re-raised (not
downgraded to a tool error) in all three execution paths:
`_execute_single_tool_with_manager`, `_agent_mode_tool_executor`, and
`_execute_tools_sdk_mode`

**3. Billing module extraction**
- All billing logic extracted from `ExecutionProcessor` into
`backend/executor/billing.py` as free functions — keeps `manager.py` and
`service.py` focused on orchestration
- `ExecutionProcessor` retains thin delegation methods
(`charge_node_usage`, `charge_extra_runtime_cost`) for backward
compatibility with blocks that call them

**4. Structured error signalling**
- Tool error detection replaced brittle `text.startswith("Tool execution
failed:")` string check with a structured `_is_error` boolean field on
the tool response dict

### Changes

- `backend/blocks/_base.py`: Add
`Block.extra_runtime_cost(execution_stats) -> int` virtual method
(default `0`)
- `backend/blocks/orchestrator.py`: Override `extra_runtime_cost`; add
tool charging in `_execute_single_tool_with_manager`; add
`InsufficientBalanceError` re-raise carve-outs in all three execution
paths; replace string-prefix error detection with `_is_error` flag
- `backend/executor/billing.py` (new): Free functions
`resolve_block_cost`, `charge_usage`, `charge_extra_runtime_cost`,
`charge_node_usage`, `handle_post_execution_billing`,
`clear_insufficient_funds_notifications` — extracted from
`ExecutionProcessor`
- `backend/executor/manager.py`: Thin delegation to `billing.*`; remove
~500 lines of billing methods from `ExecutionProcessor`
- `backend/data/credit.py`: Update lazy import source from `manager` to
`billing`
- `backend/blocks/test/test_orchestrator.py`: Add `charge_node_usage`
mock + assertion
- `backend/blocks/test/test_orchestrator_dynamic_fields.py`: Add
`charge_node_usage` async mock
- `backend/blocks/test/test_orchestrator_responses_api.py`: Add
`charge_node_usage` async mock
- `backend/blocks/test/test_orchestrator_per_iteration_cost.py`: New
test file — `extra_runtime_cost` hook, `charge_extra_runtime_cost` math
(positive/zero/negative/capped/zero-cost/block-not-found/IBE),
`charge_node_usage` delegation, `on_node_execution` gate conditions
(COMPLETED/FAILED/zero-charges/dry-run/IBE), tool charging guards
(dry-run/failed/cancelled/IBE propagation)

### Checklist

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Run `poetry run pytest
backend/blocks/test/test_orchestrator_per_iteration_cost.py`
- [ ] Verify on dev: an OrchestratorBlock run with
`agent_mode_max_iterations=5` and 5 actual iterations is charged 5x the
base cost
  - [ ] Verify tool executions inside the orchestrator are charged

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: majdyz <majdy.zamil@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: majdyz <majdyz@users.noreply.github.com>
2026-04-15 13:46:08 +07:00
Zamil Majdy
55869d3c75 fix(backend/copilot): robust context fallback — upload gate, gap-fill, token-budget compression (#12782)
## Why

During a live production session, the copilot lost all conversation
context mid-session. The model stated "I don't see any implementation
plan in our conversation" despite 9 prior turns of context. Three
compounding bugs:

**Bug 1 — Self-perpetuating upload gate:** When `restore_cli_session`
fails on a T2+ turn, `state.use_resume=False`. The old gate `and (not
has_history or state.use_resume)` then skips the CLI session upload —
even though the T1 file may exist. Each turn without `use_resume` skips
upload → next turn can't restore → also skips → etc.

**Bug 2 — Blunt message-count cap on retries:** On `prompt-too-long`,
`_reduce_context` retried 3× but rebuilt the same oversized query each
time (transcript was empty, so all 3 attempts were identical). The
`max_fallback_messages` count-cap was a blunt instrument — it threw away
middle turns blindly instead of letting the compressor summarize
intelligently.

**Bug 3 — Gap-empty path returned zero context:** When a transcript
exists but no `--resume` (CLI session unavailable), and the gap is empty
(transcript is current), the code fell through to `return
current_message, False` — the model got no history at all.

## What

1. **Remove upload gate** — upload is attempted after every successful
turn; `upload_cli_session` silently skips when the file is absent.

2. **`transcript_msg_count` set on `cli_restored=False`** — enables the
gap path on the very next turn without waiting for a full upload cycle.

3. **Token-budget compression instead of message-count cap** —
`_reduce_context` now returns `target_tokens` (50K → 15K across
retries). `compress_context` decides what to drop via LLM summarize →
content truncate → middle-out delete → first/last trim. More context
preserved at any budget vs. blindly slicing the list.

4. **Fix gap-empty case** — when transcript is current but `--resume`
unavailable, fall through to full-session compression with the token
budget instead of returning no context.

5. **Transcript seeding after fallback** — after `use_resume=False` with
no stored transcript, compress DB messages to 30K tokens and serialise
as JSONL into `transcript_builder`. Next turn uses the gap path (inject
only new messages) instead of re-compressing full history. Only fires
once per broken session (`not transcript_content` guard).

6. **Seeding guard** — seeding skips when `skip_transcript_upload=True`
(avoids wasted compression work when the result won't be saved).

7. **Structured logging** — INFO/WARNING at every branch of
`_build_query_message` with path variables, context_bytes, compression
results.

## How

**Upload gate** (`sdk/service.py` finally-block): removed `and (not
has_history or state.use_resume)`; added INFO log showing
`use_resume`/`has_history` before upload.

**`transcript_msg_count`**: set from `dl.message_count` in the
`cli_restored=False` branch.

**`_build_query_message`**: `max_fallback_messages: int | None` →
`target_tokens: int | None`; gap-empty case falls through to
full-session compression rather than returning bare message.

**`_reduce_context`**: `_FALLBACK_MSG_LIMITS` → `_RETRY_TARGET_TOKENS =
(50_000, 15_000)`; returns `ReducedContext.target_tokens`.
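A sketch of how attempts might map to budgets (the indexing is an assumption; only the 50K/15K tuple is specified here):

```python
_RETRY_TARGET_TOKENS = (50_000, 15_000)

def target_tokens_for(attempt: int) -> int | None:
    """attempt 0 is the first try (no explicit budget); retries tighten it."""
    if attempt <= 0:
        return None
    return _RETRY_TARGET_TOKENS[min(attempt, len(_RETRY_TARGET_TOKENS)) - 1]
```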

**`_compress_messages` / `_run_compression`**: both now accept
`target_tokens: int | None` and thread it through to `compress_context`.

**Seeding block**: added `not skip_transcript_upload` guard; uses
`_SEED_TARGET_TOKENS = 30_000` so the seeded JSONL is always compact
enough to pass `validate_transcript`.

## Checklist

- [x] `poetry run format` passes
- [x] No new lint errors introduced (pre-existing pyright errors
unrelated)
- [x] Tests added for `attempt` parameter and `target_tokens` in
`_reduce_context`
2026-04-15 11:49:01 +07:00
Nicholas Tindle
142c5dbe99 fix(frontend): tighten artifact preview behavior (#12770)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 20:21:05 -05:00
Abhimanyu Yadav
b06648de8c ci(frontend): add Playwright PR smoke suite with seeded QA accounts (#12682)
### Why / What / How

This PR simplifies frontend PR validation to one Playwright E2E suite,
moves redundant page-level browser coverage into Vitest integration
tests, and switches Playwright auth to deterministic seeded QA accounts.
It also folds in the follow-up fixes that came out of review and CI:
lint cleanup, CodeQL feedback, PR-local type regressions, and the flaky
Library run helper.

The approach is:
- keep Playwright focused on real browser and cross-page flows that
integration tests cannot prove well
- keep page-level render and mocked API behavior in Vitest
- remove the old PR-vs-full Playwright split from CI and run one
deterministic PR suite instead
- seed reusable auth states for fixed QA users so the browser suite is
less flaky and faster to bootstrap

### Changes 🏗️

- Removed the workflow indirection that selected different Playwright
suites for PRs vs other events
- Standardized frontend CI on a single command: `pnpm test:e2e:no-build`
- Consolidated the PR-gating Playwright suite around these happy-path
specs:
  - `auth-happy-path.spec.ts`
  - `settings-happy-path.spec.ts`
  - `api-keys-happy-path.spec.ts`
  - `builder-happy-path.spec.ts`
  - `library-happy-path.spec.ts`
  - `marketplace-happy-path.spec.ts`
  - `publish-happy-path.spec.ts`
  - `copilot-happy-path.spec.ts`
- Added the missing browser-only confidence checks to the PR suite:
  - settings persistence across reload and re-login
  - API key create, copy, and revoke
  - schedule `Run now` from Library
  - activity dropdown visibility for a real run
  - creator dashboard verification after publish submission
- Increased Playwright CI workers from `6` to `8`
- Migrated redundant page-level browser coverage into Vitest
integration/unit tests where appropriate, including marketplace,
profile, settings, API keys, signup behavior, agent dashboard row
behavior, agent activity, and utility/auth helpers
- Seeded deterministic Playwright QA users in
`backend/test/e2e_test_data.py` and reused auth states from
`frontend/src/tests/credentials/`
- Fixed CodeQL insecure randomness feedback by replacing insecure
randomness in test auth utilities
- Fixed frontend lint issues in marketplace image rendering
- Fixed PR-local type regressions introduced during test migration
- Stabilized the Library E2E run helper to support the current Library
action states: `Setup your task`, `New task`, `Rerun task`, and `Run
now`
- Removed obsolete Playwright specs and the temporary migration planning
doc once the consolidation was complete
- Reverted unintended non-test backend source changes; only backend test
fixture changes remain in scope

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm lint`
  - [x] `pnpm types`
  - [x] `pnpm test:unit`
  - [x] `pnpm exec playwright test --list`
  - [x] `pnpm test:e2e:no-build` locally
  - [ ] PR CI green after the latest push

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

Notes:
- Current local Playwright run on this branch: `28 passed`, `0 flaky`,
`0 retries`, `3m 25s`.
- Latest Codecov report on this PR showed overall coverage `63.14% ->
63.61%` (`+0.47%`), with frontend coverage up `+2.32%` and frontend E2E
coverage up `+2.10%`.
- The backend change in this PR is limited to deterministic E2E test
data setup in `backend/test/e2e_test_data.py`.
- Playwright retries remain enabled in CI; this branch does not add
fail-on-flaky behavior.

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Zamil Majdy <majdy.zamil@gmail.com>
2026-04-14 15:54:11 +00:00
Zamil Majdy
7240dd4fb1 feat(platform/admin): enhance cost dashboard with token breakdown and averages (#12757)
## Summary
- **Token breakdown in provider table**: Added separate Input Tokens and
Output Tokens columns to the By Provider table, making it easy to see
whether costs are driven by large contexts (input) or verbose
responses/thinking (output)
- **New summary cards (8 total)**: Added Avg Cost/Request, Avg Input
Tokens, Avg Output Tokens, and Total Tokens (in/out split) cards plus
P50/P75/P95/P99 cost percentile cards at the top of the dashboard for
at-a-glance cost analysis
- **Cost distribution histogram**: Added a cost distribution section
showing request count across configurable price buckets ($0–0.50,
$0.50–1, $1–2, $2–5, $5–10, $10+)
- **Per-user avg cost**: Added Avg Cost/Req column to the By User table
to identify users with unusually expensive requests
- **Backend aggregations**: Extended `PlatformCostDashboard` model with
`total_input_tokens`, `total_output_tokens`,
`avg_input_tokens_per_request`, `avg_output_tokens_per_request`,
`avg_cost_microdollars_per_request`,
`cost_p50/p75/p95/p99_microdollars`, and `cost_buckets` fields
- **Correct denominators**: Avg cost uses cost-bearing requests only;
avg token stats use token-bearing requests only — no artificial dilution
from non-cost/non-token rows

## Test plan
- [x] Verify the admin cost dashboard loads without errors at
`/admin/platform-costs`
- [x] Check that the new summary cards display correct values
- [x] Verify Input/Output Tokens columns appear in the By Provider table
- [x] Verify Avg Cost/Req column appears in the By User table
- [x] Confirm existing functionality (filters, export, rate overrides)
still works
- [x] Verify backward compatibility — new fields have defaults so old
API responses still work
2026-04-14 22:20:50 +07:00
Zamil Majdy
b4cd00bea9 dx(frontend): untrack auto-generated API client model files (#12778)
## Why
`src/app/api/__generated__/` is listed in `.gitignore` but 4 model files
were committed before that rule existed, so git kept tracking them and
they showed up in every PR that touched the API schema.

## What
Run `git rm --cached` on all 4 tracked files so the existing gitignore
rule takes effect. No gitignore content changes needed — the rule was
already correct.

## How
The `check API types` CI job only diffs `openapi.json` against the
backend's exported schema — it does not diff the generated TypeScript
models. So removing these from tracking does not break any CI check.

After this merges, `pnpm generate:api` output will be gitignored
everywhere and future API-touching PRs won't include generated model
diffs.
2026-04-14 22:19:32 +07:00
Zamil Majdy
e17914d393 perf(backend): enable cross-user prompt caching via SystemPromptPreset (#12758)
## Summary
- Use `SystemPromptPreset` with `exclude_dynamic_sections=True` in the
SDK path so the Claude Code default prompt serves as a cacheable prefix
shared across all users, reducing input token cost by ~90%
- Add `claude_agent_cross_user_prompt_cache` config field (default
`True`) to make this configurable, with fallback to raw string when
disabled
- Extract `_build_system_prompt_value()` helper for testability, with
`_SystemPromptPreset` TypedDict for proper type annotation

> **Depends on #12747** — requires SDK >=0.1.58 which adds
`SystemPromptPreset` with `exclude_dynamic_sections`. Must be merged
after #12747.

## Changes
- **`config.py`**: New `claude_agent_cross_user_prompt_cache: bool =
True` field on `ChatConfig`
- **`sdk/service.py`**: `_SystemPromptPreset` TypedDict for type safety;
`_build_system_prompt_value()` helper that constructs the preset dict or
returns the raw string; call site uses the helper
- **`sdk/service_test.py`**: Tests exercise the production
`_build_system_prompt_value()` helper directly — verifying preset dict
structure (enabled), raw string fallback (disabled), and default config
value

## How it works
The Claude Code CLI supports `SystemPromptPreset` which uses the
built-in Claude Code default prompt as a static prefix. By setting
`exclude_dynamic_sections=True`, per-user dynamic sections (working dir,
git status, auto-memory) are stripped from that prefix so it stays
identical across users and benefits from Anthropic's prompt caching. Our
custom prompt (tool notes, supplements, graphiti context) is appended
after the cacheable prefix.
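A sketch of the helper's shape, assuming the SDK preset is the usual `type`/`preset`/`append` dict; the exact `_SystemPromptPreset` field set is an assumption:

```python
def _build_system_prompt_value(custom_prompt: str, cross_user_cache: bool):
    """Sketch only; field names beyond exclude_dynamic_sections are assumed."""
    if not cross_user_cache:
        return custom_prompt  # raw-string fallback when the config flag is off
    return {
        "type": "preset",
        "preset": "claude_code",           # cacheable prefix shared across users
        "append": custom_prompt,           # our prompt comes after the prefix
        "exclude_dynamic_sections": True,  # keep the prefix user-independent
    }
```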

## Test plan
- [x] CI passes (formatting, linting, unit tests)
- [x] Verify `_build_system_prompt_value()` returns correct preset dict
when enabled
- [x] Verify fallback to raw string when
`CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false`
2026-04-14 21:30:28 +07:00
Zamil Majdy
b3a58389e5 fix(copilot): baseline cost tracking and cache token display (#12762)
## Why
The baseline copilot path (OpenAI-compatible / OpenRouter) did not
record any cost when the `x-total-cost` response header was absent, even
though token counts were always available. The admin cost dashboard also
lacked cache token columns.

## What
- **`x-total-cost` header extraction**: Reads the OpenRouter cost header
per LLM call in the `finally` block (so cost is captured even when the
stream errors mid-way). Accumulated across multi-round tool-calling
turns.
- **Cache token extraction**: Extracts
`prompt_tokens_details.cached_tokens` and `cache_creation_input_tokens`
from streaming usage chunks and passes
`cache_read_tokens`/`cache_creation_tokens` through to
`persist_and_record_usage` for storage in `PlatformCostLog`.
- **Dashboard cache token display**: Adds cache read/write columns to
the Raw Logs and By User tables on the admin platform costs dashboard.
Adds `total_cache_read_tokens` and `total_cache_creation_tokens` to
`UserCostSummary`.
- **No cost estimation**: When `x-total-cost` is absent, `cost_usd` is
left as `None` and `persist_and_record_usage` records the entry under
`tracking_type="tokens"`. Token-based cost estimation was removed — the
platform dashboard already handles per-token cost display, and estimates
would introduce inaccuracy in the reported figures.

## How
- In `_baseline_llm_caller`: extract the `x-total-cost` header in the
`finally` block; accumulate to `state.cost_usd`.
- In `_BaselineStreamState`: add `turn_cache_read_tokens` /
`turn_cache_creation_tokens` counters, populated from streaming usage
chunks.
- In `persist_and_record_usage` / `record_cost_log`: pass through
`cache_read_tokens` and `cache_creation_tokens` to `PlatformCostEntry`.
- Frontend: add `total_cache_read_tokens` /
`total_cache_creation_tokens` fields to `UserCostSummary` and render
them as columns in the cost dashboard.
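Sketch of the finally-block extraction (`state` and `response` are stand-ins for the real objects):

```python
async def consume_stream(state, response) -> None:
    try:
        async for chunk in response:
            ...  # token/usage accumulation elided
    finally:
        # Runs even when the stream errors mid-way; accumulates across the
        # multiple LLM calls of a tool-calling turn.
        raw = response.headers.get("x-total-cost")
        if raw is not None:
            state.cost_usd = (state.cost_usd or 0.0) + float(raw)
        # If the header is absent, cost_usd stays None and the entry is
        # recorded under tracking_type="tokens"; no estimation is done.
```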

## Test plan
- [x] Verify baseline copilot sessions log cost when `x-total-cost`
header is present
- [x] Verify `cost_usd` stays `None` and token count is logged when
header is absent
- [x] Verify cache tokens appear in the dashboard logs table for
sessions using prompt caching
- [x] Verify the By User tab shows Cache Read and Cache Write columns
- [x] Unit tests: `test_cost_usd_extracted_from_response_header`,
`test_cost_usd_remains_none_when_header_missing`,
`test_cache_tokens_extracted_from_usage_details`
2026-04-14 21:08:31 +07:00
186 changed files with 17160 additions and 4873 deletions

View File

@@ -48,14 +48,15 @@ git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
-2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
+2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
-3. Hooks with non-trivial business logic
+3. Shared hooks with standalone business logic when UI-level coverage is impractical
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
@@ -163,6 +164,7 @@ describe("LibraryPage", () => {
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
@@ -190,9 +192,7 @@ import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
-agents: [
-{ id: "1", name: "Test Agent", description: "A test agent" },
-],
+agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
@@ -211,6 +211,7 @@ pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass

View File

@@ -160,6 +160,7 @@ jobs:
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -288,6 +289,14 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Cache Playwright browsers
uses: actions/cache@v5
with:
path: ~/.cache/ms-playwright
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
restore-keys: |
playwright-${{ runner.os }}-
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
@@ -299,8 +308,8 @@ jobs:
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
-  - name: Run Playwright tests
-    run: pnpm test:no-build
+  - name: Run Playwright E2E suite
+    run: pnpm test:e2e:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov

.gitignore
View File

@@ -194,3 +194,4 @@ test.db
.next
# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/

View File

@@ -0,0 +1,166 @@
{
"id": "858e2226-e047-4d19-a832-3be4a134d155",
"version": 2,
"is_active": true,
"name": "Calculator agent",
"description": "",
"instructions": null,
"recommended_schedule_cron": null,
"forked_from_id": null,
"forked_from_version": null,
"user_id": "",
"created_at": "2026-04-13T03:45:11.241Z",
"nodes": [
{
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"input_default": {
"name": "Input",
"secret": false,
"advanced": false
},
"metadata": {
"position": {
"x": -188.2244873046875,
"y": 95
}
},
"input_links": [],
"output_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
"input_default": {
"name": "Output",
"secret": false,
"advanced": false,
"escape_html": false
},
"metadata": {
"position": {
"x": 825.198974609375,
"y": 123.75
}
},
"input_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"output_links": [],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
"input_default": {
"b": 34,
"operation": "Add",
"round_result": false
},
"metadata": {
"position": {
"x": 323.0255126953125,
"y": 121.25
}
},
"input_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"output_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
}
],
"links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
},
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"sub_graphs": [],
"input_schema": {
"type": "object",
"properties": {
"Input": {
"advanced": false,
"secret": false,
"title": "Input"
}
},
"required": [
"Input"
]
},
"output_schema": {
"type": "object",
"properties": {
"Output": {
"advanced": false,
"secret": false,
"title": "Output"
}
},
"required": [
"Output"
]
},
"has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"trigger_setup_info": null,
"credentials_input_schema": {
"type": "object",
"properties": {},
"required": []
}
}

View File

@@ -43,6 +43,7 @@ async def get_cost_dashboard(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
return await get_platform_cost_dashboard(
@@ -53,6 +54,7 @@ async def get_cost_dashboard(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
@@ -72,6 +74,7 @@ async def get_cost_logs(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost logs", admin_user_id)
logs, total = await get_platform_cost_logs(
@@ -84,6 +87,7 @@ async def get_cost_logs(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
total_pages = (total + page_size - 1) // page_size
return PlatformCostLogsResponse(
@@ -117,6 +121,7 @@ async def export_cost_logs(
model: str | None = Query(None),
block_name: str | None = Query(None),
tracking_type: str | None = Query(None),
graph_exec_id: str | None = Query(None),
):
logger.info("Admin %s exporting platform cost logs", admin_user_id)
logs, truncated = await get_platform_cost_logs_for_export(
@@ -127,6 +132,7 @@ async def export_cost_logs(
model=model,
block_name=block_name,
tracking_type=tracking_type,
graph_exec_id=graph_exec_id,
)
return PlatformCostExportResponse(
logs=logs,

View File

@@ -15,9 +15,10 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
-from backend.copilot.config import ChatConfig, CopilotMode
+from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -139,6 +140,11 @@ class StreamChatRequest(BaseModel):
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
model: CopilotLlmModel | None = Field(
default=None,
description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
"If None, the server applies per-user LD targeting then falls back to config.",
)
class CreateSessionRequest(BaseModel):
@@ -376,6 +382,31 @@ async def delete_session(
return Response(status_code=204)
@router.delete(
"/sessions/{session_id}/stream",
dependencies=[Security(auth.requires_user)],
status_code=204,
)
async def disconnect_session_stream(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""Disconnect all active SSE listeners for a session.
Called by the frontend when the user switches away from a chat so the
backend releases XREAD listeners immediately rather than waiting for
the 5-10 s timeout.
"""
session = await get_chat_session(session_id, user_id)
if not session:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
await stream_registry.disconnect_all_listeners(session_id)
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
@@ -810,6 +841,9 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -838,60 +872,91 @@ async def stream_chat_post(
)
request.message += files_block
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
-# Create a task in the stream registry for reconnection support
-turn_id = str(uuid4())
-log_meta["turn_id"] = turn_id
+# Create a task in the stream registry for reconnection support
+turn_id = str(uuid4())
+log_meta["turn_id"] = turn_id
-session_create_start = time.perf_counter()
-await stream_registry.create_session(
-session_id=session_id,
-user_id=user_id,
-tool_call_id="chat_stream",
-tool_name="chat",
-turn_id=turn_id,
-)
-logger.info(
-f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
-extra={
-"json_fields": {
-**log_meta,
-"duration_ms": (time.perf_counter() - session_create_start) * 1000,
-}
-},
-)
+session_create_start = time.perf_counter()
+await stream_registry.create_session(
+session_id=session_id,
+user_id=user_id,
+tool_call_id="chat_stream",
+tool_name="chat",
+turn_id=turn_id,
+)
+logger.info(
+f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
+extra={
+"json_fields": {
+**log_meta,
+"duration_ms": (time.perf_counter() - session_create_start) * 1000,
+}
+},
+)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
-await enqueue_copilot_turn(
-session_id=session_id,
-user_id=user_id,
-message=request.message,
-turn_id=turn_id,
-is_user_message=request.is_user_message,
-context=request.context,
-file_ids=sanitized_file_ids,
-mode=request.mode,
-)
+await enqueue_copilot_turn(
+session_id=session_id,
+user_id=user_id,
+message=request.message,
+turn_id=turn_id,
+is_user_message=request.is_user_message,
+context=request.context,
+file_ids=sanitized_file_ids,
+mode=request.mode,
+model=request.model,
+)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -899,6 +964,9 @@ async def stream_chat_post(
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
@@ -912,6 +980,12 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -923,8 +997,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
-return
+return  # finally releases dedup_lock
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -953,7 +1026,6 @@ async def stream_chat_post(
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
@@ -967,7 +1039,8 @@ async def stream_chat_post(
}
},
)
-break
+break  # finally releases dedup_lock
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -982,7 +1055,7 @@ async def stream_chat_post(
}
},
)
-pass # Client disconnected - background task continues
+release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -997,7 +1070,10 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:

View File

@@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
-def _mock_stream_internals(mocker: pytest_mock.MockFixture):
+def _mock_stream_internals(
+    mocker: pytest_mock.MockerFixture,
+    *,
+    redis_set_returns: object = True,
+):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
Returns:
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
@@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
mock_enqueue = mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
@@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
@@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
@@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
@@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture):
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix
assert "daily" in response.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
def test_stream_chat_returns_429_on_weekly_rate_limit(
mocker: pytest_mock.MockerFixture,
):
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi
assert "resets in" in detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
"""The 429 response detail should include the human-readable reset time."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -677,3 +704,279 @@ class TestStripInjectedContext:
result = _strip_injected_context(msg)
# Without a role, the helper short-circuits without touching content.
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}"
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
def test_disconnect_stream_returns_204_and_awaits_registry(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_session = MagicMock()
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=mock_session,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
return_value=2,
)
response = client.delete("/sessions/sess-1/stream")
assert response.status_code == 204
mock_disconnect.assert_awaited_once_with("sess-1")
def test_disconnect_stream_returns_404_when_session_missing(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=None,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
)
response = client.delete("/sessions/unknown-session/stream")
assert response.status_code == 404
mock_disconnect.assert_not_awaited()

View File

@@ -25,6 +25,7 @@ from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
is_credentials_field_name,
)
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from backend.data.model import ContributorDetails
from ..data.graph import Link
@@ -420,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Return extra runtime cost to charge after this block run completes.
Called by the executor after a block finishes with COMPLETED status.
The return value is the number of additional base-cost credits to
charge beyond the single credit already collected by charge_usage
at the start of execution. Defaults to 0 (no extra charges).
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
calls within one run and should be billed per call.
"""
return 0
def __init__(
self,
id: str = "",
@@ -455,8 +469,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
@@ -474,7 +486,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -554,7 +566,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
self.execution_stats += stats
return self.execution_stats
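How a caller consumes extra_runtime_cost is not shown in this diff. A sketch under the assumption that billing runs right after a COMPLETED execution; charge_credits is a hypothetical stand-in for the real charging call:

from backend.data.model import NodeExecutionStats

async def settle_extra_cost(block, stats: NodeExecutionStats, charge_credits):
    # charge_usage() already collected one base credit at execution start;
    # extra_runtime_cost() returns only the additional credits owed.
    extra = block.extra_runtime_cost(stats)
    if extra > 0:  # e.g. an OrchestratorBlock turn with llm_call_count=4 owes 3
        await charge_credits(extra)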

View File

@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import InsufficientBalanceError
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
from backend.util.security import SENSITIVE_FIELD_NAMES
from backend.util.tool_call_loop import (
@@ -364,10 +365,31 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
class OrchestratorBlock(Block):
"""A block that uses a language model to orchestrate tool calls.
Supports both single-shot and iterative agent mode execution.
**InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
(IBE) must always re-raise through every ``except`` block in this class.
Swallowing IBE would let the agent loop continue with unpaid work. Every
exception handler that catches ``Exception`` includes an explicit IBE
re-raise carve-out for this reason.
"""
A block that uses a language model to orchestrate tool calls, supporting both
single-shot and iterative agent mode execution.
"""
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Charge one extra runtime cost per LLM call beyond the first.
In agent mode each iteration makes one LLM call. The first is already
covered by charge_usage(); this returns the number of additional
credits so the executor can bill the remaining calls post-completion.
SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
the SDK manages its own conversation loop and only exposes aggregate
usage. We hardcode llm_call_count=1 there (the SDK does not report a
per-turn call count), so this method always returns 0 for SDK-mode
executions. Per-iteration billing does not apply to SDK mode.
"""
return max(0, execution_stats.llm_call_count - 1)
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1077,7 +1099,10 @@ class OrchestratorBlock(Block):
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
if node_exec_result is None:
raise RuntimeError(
f"upsert_execution_input returned None for node {sink_node_id}"
)
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
@@ -1112,15 +1137,86 @@ class OrchestratorBlock(Block):
task=node_exec_future,
)
# Execute the node directly since we're in the Orchestrator context
node_exec_future.set_result(
await execution_processor.on_node_execution(
# Execute the node directly since we're in the Orchestrator context.
# Wrap in try/except so the future is always resolved, even on
# error — an unresolved Future would block anything awaiting it.
#
# on_node_execution is decorated with @async_error_logged(swallow=True),
# which catches BaseException and returns None rather than raising.
# Treat a None return as a failure: set_exception so the future
# carries an error state rather than a None result, and return an
# error response so the LLM knows the tool failed.
try:
tool_node_stats = await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
if tool_node_stats is None:
nil_err = RuntimeError(
f"on_node_execution returned None for node {sink_node_id} "
"(error was swallowed by @async_error_logged)"
)
node_exec_future.set_exception(nil_err)
resp = _create_tool_response(
tool_call.id,
"Tool execution returned no result",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
node_exec_future.set_result(tool_node_stats)
except Exception as exec_err:
node_exec_future.set_exception(exec_err)
raise
# Charge user credits AFTER successful tool execution. Tools
# spawned by the orchestrator bypass the main execution queue
# (where _charge_usage is called), so we must charge here to
# avoid free tool execution. Charging post-completion (vs.
# pre-execution) avoids billing users for failed tool calls.
# Skipped for dry runs.
#
# `error is None` intentionally excludes both Exception and
# BaseException subclasses (e.g. CancelledError) so cancelled
# or terminated tool runs are not billed.
#
# Billing errors (including non-balance exceptions) are kept
# in a separate try/except so they are never silently swallowed
# by the generic tool-error handler below.
if (
not execution_params.execution_context.dry_run
and tool_node_stats.error is None
):
try:
tool_cost, _ = await execution_processor.charge_node_usage(
node_exec_entry,
)
except InsufficientBalanceError:
# IBE must propagate — see OrchestratorBlock class docstring.
# Log the billing failure here so the discarded tool result
# is traceable before the loop aborts.
logger.warning(
"Insufficient balance charging for tool node %s after "
"successful execution; agent loop will be aborted",
sink_node_id,
)
raise
except Exception:
# Non-billing charge failures (DB outage, network, etc.)
# must NOT propagate to the outer except handler because
# the tool itself succeeded. Re-raising would mark the
# tool as failed (_is_error=True), causing the LLM to
# retry side-effectful operations. Log and continue.
logger.exception(
"Unexpected error charging for tool node %s; "
"tool execution was successful",
sink_node_id,
)
tool_cost = 0
if tool_cost > 0:
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
@@ -1133,18 +1229,26 @@ class OrchestratorBlock(Block):
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(
resp = _create_tool_response(
tool_call.id, tool_response_content, responses_api=responses_api
)
resp["_is_error"] = False
return resp
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.warning("Tool execution with manager failed: %s", e)
# Return error response
return _create_tool_response(
logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
# Return a generic error to the LLM — internal exception messages
# may contain server paths, DB details, or infrastructure info.
resp = _create_tool_response(
tool_call.id,
f"Tool execution failed: {e}",
"Tool execution failed due to an internal error",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
async def _agent_mode_llm_caller(
self,
@@ -1244,13 +1348,16 @@ class OrchestratorBlock(Block):
content = str(raw_content)
else:
content = "Tool executed successfully"
tool_failed = content.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return ToolCallResult(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
content=content,
is_error=tool_failed,
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("Tool execution failed: %s", e)
return ToolCallResult(
@@ -1370,9 +1477,13 @@ class OrchestratorBlock(Block):
"arguments": tc.arguments,
},
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
# Catch all errors (validation, network, API) so that the block
# surfaces them as user-visible output instead of crashing.
# Catch all OTHER errors (validation, network, API) so that
# the block surfaces them as user-visible output instead of
# crashing.
yield "error", str(e)
return
@@ -1450,11 +1561,14 @@ class OrchestratorBlock(Block):
text = content
else:
text = json.dumps(content)
tool_failed = text.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return {
"content": [{"type": "text", "text": text}],
"isError": tool_failed,
}
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("SDK tool execution failed: %s", e)
return {
@@ -1733,11 +1847,15 @@ class OrchestratorBlock(Block):
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
pass
except InsufficientBalanceError:
# IBE must propagate — see class docstring. The `finally`
# block below still runs and records partial token usage.
raise
except Exception as e:
# Surface SDK errors as user-visible output instead of crashing,
# consistent with _execute_tools_agent_mode error handling.
# Don't return yet — fall through to merge_stats below so
# partial token usage is always recorded.
# Surface OTHER SDK errors as user-visible output instead
# of crashing, consistent with _execute_tools_agent_mode
# error handling. Don't return yet — fall through to
# merge_stats below so partial token usage is always recorded.
sdk_error = e
finally:
# Always record usage stats, even on error. The SDK may have

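The _is_error flag replaces prefix-sniffing on the response text. A reduced sketch of the producer/consumer contract, with _create_tool_response stubbed (the real one builds a provider-specific payload):

def _tool_response_stub(tool_call_id: str, content: str) -> dict:
    return {"tool_call_id": tool_call_id, "content": content}

def run_tool_sketch(tool_call_id: str, do_work) -> dict:
    try:
        resp = _tool_response_stub(tool_call_id, do_work())
        resp["_is_error"] = False  # producer marks success explicitly
    except Exception:
        resp = _tool_response_stub(
            tool_call_id, "Tool execution failed due to an internal error"
        )
        resp["_is_error"] = True
    return resp

result = run_tool_sketch("call-1", lambda: "Tool executed successfully")
# Consumer side: default to True so a missing flag counts as a failure.
assert result.get("_is_error", True) is False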
View File

@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Returns (cost, remaining_balance). Must be AsyncMock because it is
# an async method and is directly awaited in _execute_single_tool_with_manager.
# Use a non-zero cost so the merge_stats branch is exercised.
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
# Verify charge_node_usage was actually called for the successful
# tool execution — this guards against regressions where the
# post-execution tool charging is accidentally removed.
assert mock_execution_processor.charge_node_usage.call_count == 1
@pytest.mark.asyncio
async def test_orchestrator_traditional_mode_default():

View File

@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would
# return a non-awaitable tuple and TypeError out, then be
# silently swallowed by the orchestrator's catch-all.
mock_execution_processor.charge_node_usage = AsyncMock(
return_value=(0, 0)
)
async for output_name, output_value in block.run(
input_data,

View File

@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
ep.execution_stats_lock = threading.Lock()
ns = MagicMock(error=None)
ep.on_node_execution = AsyncMock(return_value=ns)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would return a
# non-awaitable tuple and TypeError out, then be silently swallowed by
# the orchestrator's catch-all.
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
block, "_create_tool_node_signatures", return_value=tool_sigs

View File

@@ -103,6 +103,7 @@ _TRANSCRIPT_UPLOAD_TIMEOUT_S = 5
# MIME types that can be embedded as vision content blocks (OpenAI format).
_VISION_MIME_TYPES = frozenset({"image/png", "image/jpeg", "image/gif", "image/webp"})
# Max size for embedding images directly in the user message (20 MiB raw).
_MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024
@@ -247,6 +248,8 @@ class _BaselineStreamState:
text_started: bool = False
turn_prompt_tokens: int = 0
turn_completion_tokens: int = 0
turn_cache_read_tokens: int = 0
turn_cache_creation_tokens: int = 0
cost_usd: float | None = None
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
session_messages: list[ChatMessage] = field(default_factory=list)
@@ -294,6 +297,18 @@ async def _baseline_llm_caller(
if chunk.usage:
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
# Extract cache token details when available (OpenAI /
# OpenRouter include these in prompt_tokens_details).
ptd = getattr(chunk.usage, "prompt_tokens_details", None)
if ptd:
state.turn_cache_read_tokens += (
getattr(ptd, "cached_tokens", 0) or 0
)
# cache_creation_input_tokens is reported by some providers
# (e.g. Anthropic native) but not standard OpenAI streaming.
state.turn_cache_creation_tokens += (
getattr(ptd, "cache_creation_input_tokens", 0) or 0
)
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
@@ -1190,16 +1205,22 @@ async def stream_chat_completion_baseline(
state.turn_prompt_tokens,
state.turn_completion_tokens,
)
# Persist token usage to session and record for rate limiting.
# NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
# cannot break out cache_read/cache_creation weights. Users on the
# baseline path may be slightly over-counted vs the SDK path.
# When prompt_tokens_details.cached_tokens is reported, subtract
# them from prompt_tokens to get the uncached count so the cost
# breakdown stays accurate.
uncached_prompt = state.turn_prompt_tokens
if state.turn_cache_read_tokens > 0:
uncached_prompt = max(
0, state.turn_prompt_tokens - state.turn_cache_read_tokens
)
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=state.turn_prompt_tokens,
prompt_tokens=uncached_prompt,
completion_tokens=state.turn_completion_tokens,
cache_read_tokens=state.turn_cache_read_tokens,
cache_creation_tokens=state.turn_cache_creation_tokens,
log_prefix="[Baseline]",
cost_usd=state.cost_usd,
model=active_model,
@@ -1269,10 +1290,13 @@ async def stream_chat_completion_baseline(
# On GeneratorExit the client is already gone, so unreachable yields
# are harmless; on normal completion they reach the SSE stream.
if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
# Report uncached prompt tokens to match what was billed — cached tokens
# are excluded so the frontend display is consistent with cost_usd.
billed_prompt = max(0, state.turn_prompt_tokens - state.turn_cache_read_tokens)
yield StreamUsage(
prompt_tokens=state.turn_prompt_tokens,
prompt_tokens=billed_prompt,
completion_tokens=state.turn_completion_tokens,
total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
total_tokens=billed_prompt + state.turn_completion_tokens,
)
yield StreamFinish()
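A worked pass through the uncached-prompt arithmetic above, with made-up counts for a provider that reports prompt_tokens_details.cached_tokens:

turn_prompt_tokens = 1_000     # total prompt tokens reported for the turn
turn_cache_read_tokens = 800   # sum of prompt_tokens_details.cached_tokens

# Billing and display both exclude cache hits from the prompt count:
uncached_prompt = max(0, turn_prompt_tokens - turn_cache_read_tokens)
assert uncached_prompt == 200

completion_tokens = 500
# StreamUsage reports the billed figure, so total_tokens is 700, not 1_500:
assert uncached_prompt + completion_tokens == 700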

View File

@@ -769,3 +769,244 @@ class TestBaselineCostExtraction:
# response was never assigned so cost extraction must not raise
assert state.cost_usd is None
@pytest.mark.asyncio
async def test_no_cost_when_header_missing(self):
"""cost_usd remains None when x-total-cost is absent."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="anthropic/claude-sonnet-4")
mock_raw = MagicMock()
mock_raw.headers = {} # no x-total-cost
mock_stream = MagicMock()
mock_stream._response = mock_raw
mock_chunk = MagicMock()
mock_chunk.usage = MagicMock()
mock_chunk.usage.prompt_tokens = 1000
mock_chunk.usage.completion_tokens = 500
mock_chunk.usage.prompt_tokens_details = None
mock_chunk.choices = []
async def chunk_aiter():
yield mock_chunk
mock_stream.__aiter__ = lambda self: chunk_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd is None
@pytest.mark.asyncio
async def test_cache_tokens_extracted_from_usage_details(self):
"""cache tokens are extracted from prompt_tokens_details.cached_tokens."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="openai/gpt-4o")
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": "0.01"}
mock_stream = MagicMock()
mock_stream._response = mock_raw
# Create a chunk with prompt_tokens_details
mock_ptd = MagicMock()
mock_ptd.cached_tokens = 800
mock_chunk = MagicMock()
mock_chunk.usage = MagicMock()
mock_chunk.usage.prompt_tokens = 1000
mock_chunk.usage.completion_tokens = 200
mock_chunk.usage.prompt_tokens_details = mock_ptd
mock_chunk.choices = []
async def chunk_aiter():
yield mock_chunk
mock_stream.__aiter__ = lambda self: chunk_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.turn_cache_read_tokens == 800
assert state.turn_prompt_tokens == 1000
@pytest.mark.asyncio
async def test_cache_creation_tokens_extracted_from_usage_details(self):
"""cache_creation_tokens are extracted from prompt_tokens_details."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="openai/gpt-4o")
mock_raw = MagicMock()
mock_raw.headers = {"x-total-cost": "0.01"}
mock_stream = MagicMock()
mock_stream._response = mock_raw
mock_ptd = MagicMock()
mock_ptd.cached_tokens = 0
mock_ptd.cache_creation_input_tokens = 500
mock_chunk = MagicMock()
mock_chunk.usage = MagicMock()
mock_chunk.usage.prompt_tokens = 1000
mock_chunk.usage.completion_tokens = 200
mock_chunk.usage.prompt_tokens_details = mock_ptd
mock_chunk.choices = []
async def chunk_aiter():
yield mock_chunk
mock_stream.__aiter__ = lambda self: chunk_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.turn_cache_creation_tokens == 500
@pytest.mark.asyncio
async def test_token_accumulators_track_across_multiple_calls(self):
"""Token accumulators grow correctly across multiple _baseline_llm_caller calls."""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="anthropic/claude-sonnet-4")
def make_stream(prompt_tokens: int, completion_tokens: int):
mock_raw = MagicMock()
mock_raw.headers = {} # no x-total-cost
mock_stream = MagicMock()
mock_stream._response = mock_raw
mock_chunk = MagicMock()
mock_chunk.usage = MagicMock()
mock_chunk.usage.prompt_tokens = prompt_tokens
mock_chunk.usage.completion_tokens = completion_tokens
mock_chunk.usage.prompt_tokens_details = None
mock_chunk.choices = []
async def chunk_aiter():
yield mock_chunk
mock_stream.__aiter__ = lambda self: chunk_aiter()
return mock_stream
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[
make_stream(1000, 200),
make_stream(1100, 300),
]
)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
await _baseline_llm_caller(
messages=[{"role": "user", "content": "follow up"}],
tools=[],
state=state,
)
# No x-total-cost header and empty pricing table -- cost_usd remains None
assert state.cost_usd is None
# Accumulators hold all tokens across both turns
assert state.turn_prompt_tokens == 2100
assert state.turn_completion_tokens == 500
@pytest.mark.asyncio
async def test_cost_usd_remains_none_when_header_missing(self):
"""cost_usd stays None when x-total-cost header is absent.
Token counts are still tracked; persist_and_record_usage handles
the None cost by falling back to tracking_type='tokens'.
"""
from backend.copilot.baseline.service import (
_baseline_llm_caller,
_BaselineStreamState,
)
state = _BaselineStreamState(model="anthropic/claude-sonnet-4")
mock_raw = MagicMock()
mock_raw.headers = {} # no x-total-cost
mock_stream = MagicMock()
mock_stream._response = mock_raw
mock_chunk = MagicMock()
mock_chunk.usage = MagicMock()
mock_chunk.usage.prompt_tokens = 1000
mock_chunk.usage.completion_tokens = 500
mock_chunk.usage.prompt_tokens_details = None
mock_chunk.choices = []
async def chunk_aiter():
yield mock_chunk
mock_stream.__aiter__ = lambda self: chunk_aiter()
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
with patch(
"backend.copilot.baseline.service._get_openai_client",
return_value=mock_client,
):
await _baseline_llm_caller(
messages=[{"role": "user", "content": "hi"}],
tools=[],
state=state,
)
assert state.cost_usd is None
assert state.turn_prompt_tokens == 1000
assert state.turn_completion_tokens == 500

View File

@@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
# Per-request model tier set by the frontend model toggle.
# 'standard' uses the global config default (currently Sonnet).
# 'advanced' forces the highest-capability model (currently Opus).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -163,12 +170,12 @@ class ChatConfig(BaseSettings):
"CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
)
claude_agent_max_budget_usd: float = Field(
default=15.0,
default=10.0,
ge=0.01,
le=1000.0,
description="Maximum spend in USD per SDK query. The CLI attempts "
"to wrap up gracefully when this budget is reached. "
"Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
"Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
"Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
)
claude_agent_max_thinking_tokens: int = Field(
@@ -197,6 +204,15 @@ class ChatConfig(BaseSettings):
description="Maximum number of retries for transient API errors "
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
)
claude_agent_cross_user_prompt_cache: bool = Field(
default=True,
description="Enable cross-user prompt caching via SystemPromptPreset. "
"The Claude Code default prompt becomes a cacheable prefix shared "
"across all users, and our custom prompt is appended after it. "
"Dynamic sections (working dir, git status, auto-memory) are excluded "
"from the prefix. Set to False to fall back to passing the system "
"prompt as a raw string.",
)
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "

View File

@@ -351,6 +351,7 @@ class CoPilotProcessor:
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
model=entry.model,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,

View File

@@ -9,7 +9,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.copilot.config import CopilotLlmModel, CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel):
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
model: CopilotLlmModel | None = None
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -180,6 +183,7 @@ async def enqueue_copilot_turn(
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -192,6 +196,7 @@ async def enqueue_copilot_turn(
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
model: Per-request model tier ('standard' or 'advanced'). None = server default.
"""
from backend.util.clients import get_async_copilot_queue
@@ -204,6 +209,7 @@ async def enqueue_copilot_turn(
context=context,
file_ids=file_ids,
mode=mode,
model=model,
)
queue_client = await get_async_copilot_queue()
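Per the config.py comments, the tier resolves request-first, then LaunchDarkly per-user targeting, then the config default. A hypothetical helper (not part of this PR) makes the fall-through explicit:

def resolve_model_tier(request_tier, ld_tier_for_user, config_default):
    # None at each level means "no preference"; fall through to the next.
    return request_tier or ld_tier_for_user or config_default

# resolve_model_tier("advanced", "standard", "standard") -> "advanced"
# resolve_model_tier(None, None, "standard")             -> "standard"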

View File

@@ -0,0 +1,71 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
def __init__(self, key: str, redis) -> None:
self._key = key
self._redis = redis
async def release(self) -> None:
"""Best-effort key deletion. The TTL handles failures silently."""
try:
await self._redis.delete(self._key)
except Exception:
pass
async def acquire_dedup_lock(
session_id: str,
message: str | None,
file_ids: list[str] | None,
) -> _DedupLock | None:
"""Acquire the idempotency lock for this (session, message, files) tuple.
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
Returns ``None`` when a duplicate is detected (lock already held).
Returns ``None`` when there is nothing to deduplicate (no message, no files).
"""
if not message and not file_ids:
return None
sorted_ids = ":".join(sorted(file_ids or []))
content_hash = hashlib.sha256(
f"{session_id}:{message or ''}:{sorted_ids}".encode()
).hexdigest()[:16]
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
redis = await get_redis_async()
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
if not acquired:
logger.warning(
f"[STREAM] Duplicate user message blocked for session {session_id}, "
f"hash={content_hash} — returning empty SSE",
)
return None
return _DedupLock(key, redis)
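A caller-side sketch for this module; emit_empty_sse stands in for the route's StreamFinish + [DONE] response, and the GeneratorExit carve-out (keep the lock on client disconnect) is handled in the routes diff, not here:

from backend.copilot.message_dedup import acquire_dedup_lock

async def handle_stream_post_sketch(session_id, message, file_ids, emit_empty_sse):
    dedup_lock = await acquire_dedup_lock(session_id, message, file_ids)
    if dedup_lock is None and (message or file_ids):
        # Duplicate within the 30 s window: finish immediately, no side effects.
        return await emit_empty_sse()
    try:
        ...  # save the message, enqueue the turn, stream the response
    finally:
        if dedup_lock:  # None can also mean "nothing to deduplicate"
            await dedup_lock.release()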

View File

@@ -0,0 +1,94 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
return mock_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Nothing to deduplicate — no Redis call made, None returned."""
mock_redis = _patch_redis(mocker, set_returns=True)
result = await acquire_dedup_lock("sess-1", None, None)
assert result is None
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
mocker: pytest_mock.MockerFixture,
) -> None:
"""First request acquires the lock and returns a _DedupLock."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
mock_redis.set.assert_called_once()
key_arg = mock_redis.set.call_args.args[0]
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Duplicate request (NX fails) returns None to signal the caller."""
_patch_redis(mocker, set_returns=None)
result = await acquire_dedup_lock("sess-1", "hello", None)
assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs are sorted before hashing so order doesn't affect the key."""
mock_redis_1 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
key_ab = mock_redis_1.set.call_args.args[0]
mock_redis_2 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
key_ba = mock_redis_2.set.call_args.args[0]
assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() calls Redis delete exactly once."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() must not raise even when Redis delete fails."""
mock_redis = _patch_redis(mocker, set_returns=True)
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release() # must not raise
mock_redis.delete.assert_called_once()

View File

@@ -302,6 +302,7 @@ async def record_token_usage(
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
model_cost_multiplier: float = 1.0,
) -> None:
"""Record token usage for a user across all windows.
@@ -315,12 +316,17 @@ async def record_token_usage(
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
``model_cost_multiplier`` scales the final weighted total to reflect
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
so that Opus turns deplete the rate limit faster, proportional to cost.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
@@ -332,7 +338,9 @@ async def record_token_usage(
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
total = round(
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
)
if total <= 0:
return
@@ -340,11 +348,12 @@ async def record_token_usage(
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
model_cost_multiplier,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,

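The weighted-total arithmetic with the new multiplier, on illustrative numbers:

prompt_tokens = 1_000          # uncached input
cache_creation_tokens = 500    # weighted at 25%
cache_read_tokens = 800        # weighted at 10%
completion_tokens = 200
model_cost_multiplier = 5.0    # Opus

weighted_input = (
    prompt_tokens
    + round(cache_creation_tokens * 0.25)  # 125
    + round(cache_read_tokens * 0.1)       # 80
)
total = round((weighted_input + completion_tokens) * max(1.0, model_cost_multiplier))
assert weighted_input == 1_205
assert total == 7_025  # the same turn on Sonnet (1.0x) would record 1_405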
View File

@@ -34,9 +34,13 @@ Steps:
always inspect the current graph first so you know exactly what to change.
Avoid using `include_graph=true` with broad keyword searches, as fetching
multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
and full input/output schemas. The `for_agent_generation=true` flag is
required to surface graph-only blocks such as AgentInputBlock,
AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock,
WebhookBlock, and MCPToolBlock. (When running MCP tools interactively
in CoPilot outside agent generation, use `run_mcp_tool` instead.)
3. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
@@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents:
### Using MCP Tools (MCPToolBlock)
> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP
> tools as persistent nodes in an agent graph. When running MCP tools directly in
> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles
> server discovery and authentication interactively. Use `MCPToolBlock` here only
> when the user wants the MCP call baked into a reusable agent graph.
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)

View File

@@ -0,0 +1,555 @@
"""Tests for context fallback paths introduced in fix/copilot-transcript-resume-gate.
Scenario table
==============
| # | use_resume | transcript_msg_count | gap | target_tokens | Expected output |
|---|------------|----------------------|---------|---------------|--------------------------------------------|
| A | True | covers all | empty | None | bare message (--resume has full context) |
| B | True | stale | 2 msgs | None | gap context prepended |
| C | True | stale | 2 msgs | 50_000 | gap compressed to budget, prepended |
| D | False | 0 | N/A | None | full session compressed, prepended |
| E | False | 0 | N/A | 50_000 | full session compressed to budget |
| F | False | 2 (partial) | 2 msgs | None | full session compressed (not just gap; |
| | | | | | CLI has zero context without --resume) |
| G | False | 2 (partial) | 2 msgs | 50_000 | full session compressed to budget |
| H | False | covers all | empty | None | full session compressed |
| | | | | | (NOT bare message — the bug that was fixed)|
| I | False | covers all | empty | 50_000 | full session compressed to tight budget |
| J | False | 2 (partial) | n/a | None | exactly ONE compression call (full prior) |
Compression unit tests
=======================
| # | Input | target_tokens | Expected |
|---|----------------------|---------------|-----------------------------------------------|
| K | [] | None | ([], False) — empty guard |
| L | [1 msg] | None | ([msg], False) — single-msg guard |
| M | [2+ msgs] | None | target_tokens=None forwarded to _run_compression |
| N | [2+ msgs] | 30_000 | target_tokens=30_000 forwarded |
| O | [2+ msgs], run fails | None | returns originals, False |
"""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import _build_query_message, _compress_messages
from backend.util.prompt import CompressResult
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(messages: list[ChatMessage]) -> ChatSession:
now = datetime.now(UTC)
return ChatSession(
session_id="test-session",
user_id="user-1",
messages=messages,
title="test",
usage=[],
started_at=now,
updated_at=now,
)
def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
return [ChatMessage(role=r, content=c) for r, c in pairs]
def _passthrough_compress(target_tokens=None):
"""Return a mock that passes messages through and records its call args."""
calls: list[tuple[list, int | None]] = []
async def _mock(msgs, tok=None):
calls.append((msgs, tok))
return msgs, False
_mock.calls = calls # type: ignore[attr-defined]
return _mock
# ---------------------------------------------------------------------------
# _build_query_message — scenarios A-J
# ---------------------------------------------------------------------------
class TestBuildQueryMessageResume:
"""use_resume=True paths (--resume supplies history; only inject gap if stale)."""
@pytest.mark.asyncio
async def test_scenario_a_transcript_current_returns_bare_message(self):
"""Scenario A: --resume covers full context → no prefix injected."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
result, compacted = await _build_query_message(
"q2", session, use_resume=True, transcript_msg_count=2, session_id="s"
)
assert result == "q2"
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_b_stale_transcript_injects_gap(self, monkeypatch):
"""Scenario B: stale transcript → gap context prepended."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, compacted = await _build_query_message(
"q3", session, use_resume=True, transcript_msg_count=2, session_id="s"
)
assert "<conversation_history>" in result
assert "q2" in result
assert "a2" in result
assert "Now, the user says:\nq3" in result
# q1/a1 are covered by the transcript — must NOT appear in gap context
assert "q1" not in result
@pytest.mark.asyncio
async def test_scenario_c_stale_transcript_passes_target_tokens(self, monkeypatch):
"""Scenario C: target_tokens is forwarded to _compress_messages for the gap."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=True,
transcript_msg_count=2,
session_id="s",
target_tokens=50_000,
)
assert captured == [50_000]
class TestBuildQueryMessageNoResumeNoTranscript:
"""use_resume=False, transcript_msg_count=0 — full session compressed."""
@pytest.mark.asyncio
async def test_scenario_d_full_session_compressed(self, monkeypatch):
"""Scenario D: no resume, no transcript → compress all prior messages."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, compacted = await _build_query_message(
"q2", session, use_resume=False, transcript_msg_count=0, session_id="s"
)
assert "<conversation_history>" in result
assert "q1" in result
assert "a1" in result
assert "Now, the user says:\nq2" in result
@pytest.mark.asyncio
async def test_scenario_e_passes_target_tokens_to_compression(self, monkeypatch):
"""Scenario E: target_tokens forwarded to _compress_messages."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q2",
session,
use_resume=False,
transcript_msg_count=0,
session_id="s",
target_tokens=15_000,
)
assert captured == [15_000]
class TestBuildQueryMessageNoResumeWithTranscript:
"""use_resume=False, transcript_msg_count > 0 — gap or full-session fallback."""
@pytest.mark.asyncio
async def test_scenario_f_no_resume_always_injects_full_session(self, monkeypatch):
"""Scenario F: use_resume=False with transcript_msg_count > 0 still injects
the FULL prior session — not just the gap since the transcript end.
When there is no --resume the CLI starts with zero context, so injecting
only the post-transcript gap would silently drop all transcript-covered
history. The correct fix is to always compress the full session.
"""
session = _make_session(
_msgs(
("user", "q1"), # transcript_msg_count=2 covers these
("assistant", "a1"),
("user", "q2"), # post-transcript gap starts here
("assistant", "a2"),
("user", "q3"), # current message
)
)
compressed_msgs: list[list] = []
async def _mock_compress(msgs, target_tokens=None):
compressed_msgs.append(list(msgs))
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, _ = await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2, # transcript covers q1/a1 but no --resume
session_id="s",
)
assert "<conversation_history>" in result
# Full session must be injected — transcript-covered turns ARE included
assert "q1" in result
assert "a1" in result
assert "q2" in result
assert "a2" in result
assert "Now, the user says:\nq3" in result
# Compressed exactly once with all 4 prior messages
assert len(compressed_msgs) == 1
assert len(compressed_msgs[0]) == 4
@pytest.mark.asyncio
async def test_scenario_g_no_resume_passes_target_tokens(self, monkeypatch):
"""Scenario G: target_tokens forwarded when use_resume=False + transcript_msg_count > 0."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
target_tokens=50_000,
)
assert captured == [50_000]
@pytest.mark.asyncio
async def test_scenario_h_no_resume_transcript_current_injects_full_session(
self, monkeypatch
):
"""Scenario H: the bug that was fixed.
Old code path: use_resume=False, transcript_msg_count covers all prior
messages → gap sub-path: gap = [] → ``return current_message, False``
→ model received ZERO context (bare message only).
New code path: use_resume=False always compresses the full prior session
regardless of transcript_msg_count — model always gets context.
"""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, _ = await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=4, # covers ALL prior → old code returned bare msg
session_id="s",
)
# NEW: must inject full session, NOT return bare message
assert result != "q3"
assert "<conversation_history>" in result
assert "q1" in result
assert "Now, the user says:\nq3" in result
@pytest.mark.asyncio
async def test_scenario_i_no_resume_target_tokens_forwarded_any_transcript_count(
self, monkeypatch
):
"""Scenario I: target_tokens forwarded even when transcript_msg_count covers all."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q2",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
target_tokens=15_000,
)
assert 15_000 in captured
@pytest.mark.asyncio
async def test_scenario_j_no_resume_single_compression_call(self, monkeypatch):
"""Scenario J: use_resume=False always makes exactly ONE compression call
(the full session), regardless of transcript coverage.
This verifies there is no two-step gap+fallback pattern for no-resume —
compression is called once with the full prior session.
"""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
call_count = 0
async def _mock_compress(msgs, target_tokens=None):
nonlocal call_count
call_count += 1
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
)
assert call_count == 1
# ---------------------------------------------------------------------------
# _compress_messages — unit tests (Scenarios K–O)
# ---------------------------------------------------------------------------
class TestCompressMessages:
@pytest.mark.asyncio
async def test_scenario_k_empty_list_returns_empty(self):
"""Scenario K: empty input → short-circuit, no compression."""
result, compacted = await _compress_messages([])
assert result == []
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_l_single_message_returns_as_is(self):
"""Scenario L: single message → short-circuit (< 2 guard)."""
msg = ChatMessage(role="user", content="hello")
result, compacted = await _compress_messages([msg])
assert result == [msg]
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_m_target_tokens_none_forwarded(self):
"""Scenario M: target_tokens=None forwarded to _run_compression."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
fake_result = CompressResult(
messages=[
{"role": "user", "content": "q"},
{"role": "assistant", "content": "a"},
],
token_count=10,
was_compacted=False,
original_token_count=10,
)
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
return_value=fake_result,
) as mock_run:
await _compress_messages(msgs, target_tokens=None)
mock_run.assert_awaited_once()
_, kwargs = mock_run.call_args
assert kwargs.get("target_tokens") is None
@pytest.mark.asyncio
async def test_scenario_n_explicit_target_tokens_forwarded(self):
"""Scenario N: explicit target_tokens forwarded to _run_compression."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
fake_result = CompressResult(
messages=[{"role": "user", "content": "summary"}],
token_count=5,
was_compacted=True,
original_token_count=50,
)
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
return_value=fake_result,
) as mock_run:
result, compacted = await _compress_messages(msgs, target_tokens=30_000)
mock_run.assert_awaited_once()
_, kwargs = mock_run.call_args
assert kwargs.get("target_tokens") == 30_000
assert compacted is True
@pytest.mark.asyncio
async def test_scenario_o_run_compression_exception_returns_originals(self):
"""Scenario O: _run_compression raises → return original messages, False."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("compression timeout"),
):
result, compacted = await _compress_messages(msgs)
assert result == msgs
assert compacted is False
@pytest.mark.asyncio
async def test_compaction_messages_filtered_before_compression(self):
"""filter_compaction_messages is applied before _run_compression is called."""
# A compaction message is one with role=assistant and specific content pattern.
# We verify that only real messages reach _run_compression.
from backend.copilot.sdk.service import filter_compaction_messages
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
# filter_compaction_messages should not remove these plain messages
filtered = filter_compaction_messages(msgs)
assert len(filtered) == len(msgs)
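Read together, Scenarios K through O pin down a small contract for `_compress_messages`. The sketch below restates that contract under the signature shown in the service diff; `compress_fn` is a stand-in for the real `_run_compression` plumbing, not the production call.
async def compress_messages_sketch(messages, compress_fn, target_tokens=None):
    """Illustrative contract only; not the production implementation."""
    if len(messages) < 2:
        # Scenarios K/L: empty or single-message input short-circuits.
        return messages, False
    try:
        # Scenarios M/N: target_tokens is forwarded verbatim (None allowed).
        result = await compress_fn(messages, target_tokens=target_tokens)
    except Exception:
        # Scenario O: any compression failure falls back to the originals.
        return messages, False
    return result.messages, result.was_compacted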
# ---------------------------------------------------------------------------
# target_tokens threading — _retry_target_tokens values match expectations
# ---------------------------------------------------------------------------
class TestRetryTargetTokens:
def test_first_retry_uses_first_slot(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[0] == 50_000
def test_second_retry_uses_second_slot(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[1] == 15_000
def test_second_slot_smaller_than_first(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[1] < _RETRY_TARGET_TOKENS[0]
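These slots are consumed by `_reduce_context` via a clamped 1-based attempt index, as the service diff later shows; a self-contained restatement of that lookup:
_RETRY_TARGET_TOKENS_SKETCH = (50_000, 15_000)

def retry_budget(attempt: int) -> int:
    # Clamp the 1-based attempt into the tuple so attempt 3+ keeps
    # reusing the last (most aggressive) budget instead of overflowing.
    idx = max(0, attempt - 1)
    return _RETRY_TARGET_TOKENS_SKETCH[min(idx, len(_RETRY_TARGET_TOKENS_SKETCH) - 1)]

assert retry_budget(1) == 50_000
assert retry_budget(2) == 15_000
assert retry_budget(99) == 15_000  # clamped beyond the tuple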
# ---------------------------------------------------------------------------
# Single-message session edge cases
# ---------------------------------------------------------------------------
class TestSingleMessageSessions:
@pytest.mark.asyncio
async def test_no_resume_single_message_returns_bare(self):
"""First turn (1 message): no prior history to inject."""
session = _make_session([ChatMessage(role="user", content="hello")])
result, compacted = await _build_query_message(
"hello", session, use_resume=False, transcript_msg_count=0, session_id="s"
)
assert result == "hello"
assert compacted is False
@pytest.mark.asyncio
async def test_resume_single_message_returns_bare(self):
"""First turn with resume flag: transcript is empty so no gap."""
session = _make_session([ChatMessage(role="user", content="hello")])
result, compacted = await _build_query_message(
"hello", session, use_resume=True, transcript_msg_count=0, session_id="s"
)
assert result == "hello"
assert compacted is False

View File

@@ -0,0 +1,326 @@
"""Tests for transcript context coverage when switching between fast and SDK modes.
When a user switches modes mid-session the transcript must bridge the gap so
neither the baseline nor the SDK service loses context from turns produced by
the other mode.
Cross-mode transcript flow
==========================
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
mode) read and write the same JSONL transcript store via
``backend.copilot.transcript.upload_transcript`` /
``download_transcript``.
Fast → SDK switch
-----------------
On the first SDK turn after N baseline turns:
• ``use_resume=False`` — no CLI session exists from baseline mode.
• ``transcript_msg_count > 0`` — the baseline transcript is downloaded and
validated successfully.
• ``_build_query_message`` must inject the FULL prior session (not just a
"gap" since the transcript end) because the CLI has zero context without
``--resume``.
• After our fix, ``session_id`` IS set, so the CLI writes a session file
on this turn → ``--resume`` works on T2+.
SDK → Fast switch
-----------------
On the first baseline turn after N SDK turns:
• The baseline service downloads the SDK-written transcript.
• ``_load_prior_transcript`` loads and validates it normally — the JSONL
format is identical regardless of which mode wrote it.
• ``transcript_covers_prefix=True`` → baseline sends ONLY new messages in
its LLM payload (no double-counting of SDK history).
Scenario table (SDK _build_query_message)
==========================================
| # | Scenario | use_resume | tmc | Expected query message |
|---|--------------------------------|------------|-----|---------------------------------|
| P | Fast→SDK T1 | False | 4 | full session injected |
| Q | Fast→SDK T2+ (after fix) | True | 6 | bare message only (--resume ok) |
| R | Fast→SDK T1, single baseline | False | 2 | full session injected |
| S | SDK→Fast (baseline loads ok) | N/A | N/A | transcript covers prefix=True |
"""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import _build_query_message
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(messages: list[ChatMessage]) -> ChatSession:
now = datetime.now(UTC)
return ChatSession(
session_id="test-session",
user_id="user-1",
messages=messages,
title="test",
usage=[],
started_at=now,
updated_at=now,
)
def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
return [ChatMessage(role=r, content=c) for r, c in pairs]
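Before the scenario classes, a compact restatement of the P/Q/R rows from the table above. This pure function only mirrors the behaviors the tests in this file assert, and its names are illustrative rather than production code; the watermark-misalignment guard from the service diff is omitted.
def expected_query_shape(
    use_resume: bool, transcript_msg_count: int, prior_msg_count: int
) -> str:
    """Which query message _build_query_message is expected to produce."""
    if use_resume:
        # Scenario Q: --resume supplies the history natively; only a
        # stale transcript adds a gap prefix on top of the message.
        if transcript_msg_count >= prior_msg_count:
            return "bare message"
        return "gap prefix + message"
    if prior_msg_count == 0:
        return "bare message"  # true first turn: nothing to inject
    # Scenarios P/R: no --resume means zero CLI context, so the FULL
    # prior session is compressed and injected, whatever the transcript covers.
    return "full session + message"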
# ---------------------------------------------------------------------------
# Scenario P — Fast → SDK T1: full session injected from baseline transcript
# ---------------------------------------------------------------------------
class TestFastToSdkModeSwitch:
"""First SDK turn after N baseline (fast) turns.
The baseline transcript exists (has been uploaded by fast mode), but
there is no CLI session file. ``_build_query_message`` must inject
the complete prior session so the model has full context.
"""
@pytest.mark.asyncio
async def test_scenario_p_full_session_injected_on_mode_switch_t1(
self, monkeypatch
):
"""Scenario P: fast→SDK T1 injects all baseline turns into the query."""
# Simulate 4 baseline messages (2 turns) followed by the first SDK turn.
session = _make_session(
_msgs(
("user", "baseline-q1"),
("assistant", "baseline-a1"),
("user", "baseline-q2"),
("assistant", "baseline-a2"),
("user", "sdk-q1"), # current SDK turn
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
# transcript_msg_count=4: baseline uploaded a transcript covering all
# 4 prior messages, but use_resume=False (no CLI session from baseline).
result, compacted = await _build_query_message(
"sdk-q1",
session,
use_resume=False,
transcript_msg_count=4,
session_id="s",
)
# All baseline turns must appear — none of them can be silently dropped.
assert "<conversation_history>" in result
assert "baseline-q1" in result
assert "baseline-a1" in result
assert "baseline-q2" in result
assert "baseline-a2" in result
assert "Now, the user says:\nsdk-q1" in result
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_r_single_baseline_turn_injected(self, monkeypatch):
"""Scenario R: even a single baseline turn is captured on mode-switch T1."""
session = _make_session(
_msgs(
("user", "baseline-q1"),
("assistant", "baseline-a1"),
("user", "sdk-q1"),
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, _ = await _build_query_message(
"sdk-q1",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
)
assert "<conversation_history>" in result
assert "baseline-q1" in result
assert "baseline-a1" in result
assert "Now, the user says:\nsdk-q1" in result
@pytest.mark.asyncio
async def test_scenario_q_sdk_t2_uses_resume_after_fix(self):
"""Scenario Q: SDK T2+ uses --resume after mode-switch T1 set session_id.
With the mode-switch fix, T1 sets session_id → CLI writes session file →
T2 restores the session → use_resume=True. _build_query_message must
return the bare message (--resume supplies context via native session).
"""
# T2: 4 baseline turns + 1 SDK turn already recorded.
session = _make_session(
_msgs(
("user", "baseline-q1"),
("assistant", "baseline-a1"),
("user", "baseline-q2"),
("assistant", "baseline-a2"),
("user", "sdk-q1"),
("assistant", "sdk-a1"),
("user", "sdk-q2"), # current SDK T2 message
)
)
# transcript_msg_count=6 covers all prior messages → no gap.
result, compacted = await _build_query_message(
"sdk-q2",
session,
use_resume=True, # T2: --resume works after T1 set session_id
transcript_msg_count=6,
session_id="s",
)
# --resume has full context — bare message only.
assert result == "sdk-q2"
assert compacted is False
@pytest.mark.asyncio
async def test_mode_switch_t1_compresses_all_baseline_turns(self, monkeypatch):
"""_compress_messages is called with ALL prior baseline messages.
There is exactly one compression call containing all 4 baseline messages
— not just the 2 post-transcript-end messages.
"""
session = _make_session(
_msgs(
("user", "baseline-q1"),
("assistant", "baseline-a1"),
("user", "baseline-q2"),
("assistant", "baseline-a2"),
("user", "sdk-q1"),
)
)
compressed_batches: list[list] = []
async def _mock_compress(msgs, target_tokens=None):
compressed_batches.append(list(msgs))
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"sdk-q1",
session,
use_resume=False,
transcript_msg_count=4,
session_id="s",
)
# Exactly one compression call, with all 4 prior messages.
assert len(compressed_batches) == 1
assert len(compressed_batches[0]) == 4
# ---------------------------------------------------------------------------
# Scenario S — SDK → Fast: baseline loads SDK-written transcript
# ---------------------------------------------------------------------------
class TestSdkToFastModeSwitch:
"""Fast mode turn after N SDK (extended_thinking) turns.
The transcript written by SDK mode uses the same JSONL format as the one
written by baseline mode (both go through ``TranscriptBuilder``).
``_load_prior_transcript`` must accept it and mark the prefix as covered.
"""
@pytest.mark.asyncio
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
# Build a minimal valid transcript as SDK mode would write it.
# SDK uses append_user / append_assistant on TranscriptBuilder.
builder_sdk = TranscriptBuilder()
builder_sdk.append_user(content="sdk-question")
builder_sdk.append_assistant(
content_blocks=[{"type": "text", "text": "sdk-answer"}],
model="claude-sonnet-4",
stop_reason=STOP_REASON_END_TURN,
)
sdk_transcript = builder_sdk.to_jsonl()
# Baseline session now has those 2 SDK messages + 1 new baseline message.
download = TranscriptDownload(content=sdk_transcript, message_count=2)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
transcript_builder=baseline_builder,
)
# Transcript is valid and covers the prefix.
assert covers is True
assert baseline_builder.entry_count == 2
@pytest.mark.asyncio
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
If SDK mode produced more turns than the transcript captured (e.g.
upload failed on one turn), the baseline rejects the stale transcript
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
builder_sdk = TranscriptBuilder()
builder_sdk.append_user(content="sdk-question")
builder_sdk.append_assistant(
content_blocks=[{"type": "text", "text": "sdk-answer"}],
model="claude-sonnet-4",
stop_reason=STOP_REASON_END_TURN,
)
sdk_transcript = builder_sdk.to_jsonl()
# Transcript covers only 2 messages but session has 10 (many SDK turns).
download = TranscriptDownload(content=sdk_transcript, message_count=2)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
transcript_builder=baseline_builder,
)
# Stale transcript must be rejected.
assert covers is False
assert baseline_builder.is_empty
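Both Scenario S tests turn on whether the downloaded transcript accounts for every prior message. A minimal sketch of that coverage check, assuming the current user message is never in the transcript; the exact predicate inside `_load_prior_transcript` may differ in detail.
def transcript_covers_prefix_sketch(
    transcript_msg_count: int, session_msg_count: int
) -> bool:
    # A fresh transcript covers everything except the in-flight message.
    # Anything less is stale and must be rejected, not partially injected.
    return transcript_msg_count >= session_msg_count - 1

assert transcript_covers_prefix_sketch(2, 3) is True   # 2 SDK msgs + 1 new
assert transcript_covers_prefix_sketch(2, 10) is False  # stale: turns missing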

View File

@@ -207,7 +207,7 @@ class TestConfigDefaults:
def test_max_budget_usd_default(self):
cfg = _make_config()
assert cfg.claude_agent_max_budget_usd == 15.0
assert cfg.claude_agent_max_budget_usd == 10.0
def test_max_thinking_tokens_default(self):
cfg = _make_config()

View File

@@ -6,6 +6,7 @@ import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
_BARE_MESSAGE_TOKEN_FLOOR,
_build_query_message,
_format_conversation_context,
)
@@ -130,6 +131,34 @@ async def test_build_query_resume_up_to_date():
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_resume_misaligned_watermark():
"""With --resume and watermark pointing at a user message, skip gap."""
# Simulates a deleted message shifting DB positions so the watermark
# lands on a user turn instead of the expected assistant turn.
session = _make_session(
[
ChatMessage(role="user", content="turn 1"),
ChatMessage(role="assistant", content="reply 1"),
ChatMessage(
role="user", content="turn 2"
), # ← watermark points here (role=user)
ChatMessage(role="assistant", content="reply 2"),
ChatMessage(role="user", content="turn 3"),
]
)
result, was_compacted = await _build_query_message(
"turn 3",
session,
use_resume=True,
transcript_msg_count=3, # prior[2].role == "user" — misaligned
session_id="test-session",
)
# Misaligned watermark → skip gap, return bare message
assert result == "turn 3"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
"""With --resume and stale transcript, gap context is prepended."""
@@ -204,7 +233,7 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
)
# Mock _compress_messages to return the messages as-is
async def _mock_compress(msgs):
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
@@ -237,7 +266,7 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
]
)
async def _mock_compress(msgs):
async def _mock_compress(msgs, target_tokens=None):
return msgs, True # Simulate actual compaction
monkeypatch.setattr(
@@ -253,3 +282,85 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
session_id="test-session",
)
assert was_compacted is True
@pytest.mark.asyncio
async def test_build_query_no_resume_at_token_floor():
"""When target_tokens is at or below the floor, return bare message.
This is the final escape hatch: if the retry budget is exhausted and
even the most aggressive compression might not fit, skip history
injection entirely so the user always gets a response.
"""
session = _make_session(
[
ChatMessage(role="user", content="old question"),
ChatMessage(role="assistant", content="old answer"),
ChatMessage(role="user", content="new question"),
]
)
result, was_compacted = await _build_query_message(
"new question",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR,
)
# At the floor threshold, no history is injected
assert result == "new question"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_below_token_floor():
"""target_tokens strictly below floor also returns bare message."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR - 1,
)
assert result == "new"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_above_token_floor_compresses(monkeypatch):
"""target_tokens just above the floor still triggers compression."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR + 1,
)
# Above the floor → history is injected (not the bare message)
assert "<conversation_history>" in result
assert "Now, the user says:\nnew" in result

View File

@@ -7,6 +7,7 @@ tests will catch it immediately.
"""
import inspect
from typing import cast
import pytest
@@ -90,6 +91,39 @@ def test_agent_options_accepts_required_fields():
assert opts.cwd == "/tmp"
def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
"""Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.
The production code always includes ``exclude_dynamic_sections=True`` in the preset
dict. This compat test mirrors that exact shape so any SDK version that starts
rejecting unknown keys will be caught here rather than at runtime.
"""
from claude_agent_sdk import ClaudeAgentOptions
from claude_agent_sdk.types import SystemPromptPreset
from .service import _build_system_prompt_value
# Call the production helper directly so this test is tied to the real
# dict shape rather than a hand-rolled copy.
preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
assert isinstance(
preset, dict
), "_build_system_prompt_value must return a dict when caching is on"
sdk_preset = cast(SystemPromptPreset, preset)
opts = ClaudeAgentOptions(system_prompt=sdk_preset)
assert opts.system_prompt == sdk_preset
def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
"""When cross_user_cache=False (e.g. on --resume turns), the helper must return
a plain string so the preset+resume crash is avoided."""
from .service import _build_system_prompt_value
result = _build_system_prompt_value("my prompt", cross_user_cache=False)
assert result == "my prompt", "Must return the raw string, not a preset dict"
def test_agent_options_accepts_all_our_fields():
"""Comprehensive check of every field we use in service.py."""
from claude_agent_sdk import ClaudeAgentOptions

View File

@@ -1,5 +1,6 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
# isort: skip_file
import asyncio
import base64
import json
@@ -17,7 +18,7 @@ from dataclasses import field as dataclass_field
from typing import TYPE_CHECKING, Any, NamedTuple, cast
if TYPE_CHECKING:
from backend.copilot.permissions import CopilotPermissions
from ..permissions import CopilotPermissions
from claude_agent_sdk import (
AssistantMessage,
@@ -29,16 +30,17 @@ from claude_agent_sdk import (
ToolResultBlock,
ToolUseBlock,
)
from claude_agent_sdk.types import SystemPromptPreset
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.copilot.thinking_stripper import ThinkingStripper
from backend.copilot.transcript import (
from ..context import get_workspace_manager
from ..permissions import apply_tool_permissions
from ..rate_limit import get_user_tier
from ..thinking_stripper import ThinkingStripper
from ..transcript import (
_run_compression,
cleanup_stale_project_dirs,
compact_transcript,
@@ -49,13 +51,13 @@ from backend.copilot.transcript import (
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from ..transcript_builder import TranscriptBuilder
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from ..config import ChatConfig, CopilotMode
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
from ..constants import (
COPILOT_ERROR_PREFIX,
COPILOT_RETRYABLE_ERROR_PREFIX,
@@ -131,6 +133,11 @@ _MAX_STREAM_ATTEMPTS = 3
# self-correct. The limit is generous to allow recovery attempts.
_EMPTY_TOOL_CALL_LIMIT = 5
# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet
# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus
# turns deplete quota proportionally faster.
_OPUS_COST_MULTIPLIER = 5.0
# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
"AutoPilot was unable to complete the tool call "
@@ -260,6 +267,11 @@ class ReducedContext(NamedTuple):
resume_file: str | None
transcript_lost: bool
tried_compaction: bool
# Token budget for history compression on the DB-message fallback path.
# None means "use model-aware default". Halved on each retry so
# compress_context applies progressively more aggressive reduction
# (LLM summarize → content truncate → middle-out delete → first/last trim).
target_tokens: int | None = None
@dataclass
@@ -304,6 +316,10 @@ class _RetryState:
adapter: SDKResponseAdapter
transcript_builder: TranscriptBuilder
usage: _TokenUsage
# Token budget for history compression on retries (DB-message fallback path).
    # None = model-aware default. Reduced each retry for progressively more
# aggressive compression (LLM summarize → truncate → middle-out → trim).
target_tokens: int | None = None
@dataclass
@@ -335,12 +351,34 @@ class _StreamContext:
lock: AsyncClusterLock
# Per-retry token budgets for the no-transcript (use_resume=False) path.
# When there is no CLI native session to --resume, context is built from DB
# messages via _format_conversation_context. For large sessions this text
# can exceed the model context window; each retry shrinks the token budget so
# compress_context applies progressively more aggressive reduction:
# LLM summarize → content truncate → middle-out delete → first/last trim.
# Index 0 = first retry, 1 = second retry; last value applies beyond that.
_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)
# Below this token budget the model context is so tight that injecting any
# conversation history would likely exceed the limit regardless of content.
# _build_query_message returns the bare message when target_tokens falls to
# or below this floor, giving the user a response instead of a hard error.
_BARE_MESSAGE_TOKEN_FLOOR: int = 5_000
# Tight token budget for seeding the transcript builder on turns where no
# CLI native session exists. Kept below _RETRY_TARGET_TOKENS[0] so the
# seeded JSONL upload stays compact and future gap injections are small.
_SEED_TARGET_TOKENS: int = 30_000
async def _reduce_context(
transcript_content: str,
tried_compaction: bool,
session_id: str,
sdk_cwd: str,
log_prefix: str,
attempt: int = 1,
) -> ReducedContext:
"""Prepare reduced context for a retry attempt.
@@ -348,9 +386,19 @@ async def _reduce_context(
On subsequent retries (or if compaction fails), drops the transcript
entirely so the query is rebuilt from DB messages only.
`transcript_lost` is True when the transcript was dropped (caller
should set `skip_transcript_upload`).
When no transcript is available (use_resume=False fallback path), returns
a decreasing ``target_tokens`` budget so ``compress_context`` applies
progressively more aggressive reduction (LLM summarize → content truncate
→ middle-out delete → first/last trim). The budget applies in
    ``_build_query_message`` and is reduced on each retry.
``transcript_lost`` is True when the transcript was dropped (caller
should set ``skip_transcript_upload``).
"""
# Token budget for the DB fallback on this attempt (no-transcript path).
idx = max(0, attempt - 1)
retry_target = _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]
# First retry: try compacting our transcript builder state.
# Note: the CLI native --resume file is not updated with the compacted
# content (it would require emitting CLI-native JSONL format), so the
@@ -374,9 +422,14 @@ async def _reduce_context(
return ReducedContext(tb, False, None, False, True)
logger.warning("%s Compaction failed, dropping transcript", log_prefix)
# Subsequent retry or compaction failed: drop transcript entirely
logger.warning("%s Dropping transcript, rebuilding from DB messages", log_prefix)
return ReducedContext(TranscriptBuilder(), False, None, True, True)
# Subsequent retry or compaction failed: drop transcript entirely.
# Return retry_target so the caller compresses DB messages to that budget.
logger.warning(
"%s Dropping transcript, rebuilding from DB messages" " (target_tokens=%d)",
log_prefix,
retry_target,
)
return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)
def _append_error_marker(
@@ -627,6 +680,48 @@ def _resolve_fallback_model() -> str | None:
return _normalize_model_name(raw)
async def _resolve_model_and_multiplier(
model: "CopilotLlmModel | None",
session_id: str,
) -> tuple[str | None, float]:
"""Resolve the SDK model string and rate-limit cost multiplier for a turn.
Priority (highest first):
1. Explicit per-request ``model`` tier from the frontend toggle.
2. Global config default (``_resolve_sdk_model()``).
Returns a ``(sdk_model, cost_multiplier)`` pair.
``sdk_model`` is ``None`` when the Claude Code subscription default applies.
``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise.
"""
sdk_model = _resolve_sdk_model()
if model == "advanced":
sdk_model = _normalize_model_name("anthropic/claude-opus-4-6")
logger.info(
"[SDK] [%s] Per-request model override: advanced (%s)",
session_id[:12] if session_id else "?",
sdk_model,
)
return sdk_model, _OPUS_COST_MULTIPLIER
if model == "standard":
# Reset to config default — respects subscription mode (None = CLI default).
sdk_model = _resolve_sdk_model()
logger.info(
"[SDK] [%s] Per-request model override: standard (%s)",
session_id[:12] if session_id else "?",
sdk_model or "subscription-default",
)
return sdk_model, 1.0
# No per-request override; derive multiplier from final resolved model.
cost_multiplier = (
_OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0
)
return sdk_model, cost_multiplier
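One way the returned multiplier could feed quota accounting; everything in this sketch except the 5.0 Opus weighting is hypothetical, not the production rate-limit API.
class QuotaCounterSketch:
    """Hypothetical consumer of the (sdk_model, cost_multiplier) pair."""

    def __init__(self) -> None:
        self.used = 0.0

    def charge(self, base_cost: float, multiplier: float) -> None:
        # Opus turns carry multiplier=5.0, so they deplete the same
        # quota five times faster than Sonnet turns.
        self.used += base_cost * multiplier

counter = QuotaCounterSketch()
counter.charge(1.0, 1.0)  # 'standard' turn
counter.charge(1.0, 5.0)  # 'advanced' (Opus) turn
assert counter.used == 6.0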
_MAX_TRANSIENT_BACKOFF_SECONDS = 30
@@ -705,6 +800,34 @@ def _is_fallback_stderr(line: str) -> bool:
return "fallback model" in line.lower()
def _build_system_prompt_value(
system_prompt: str,
cross_user_cache: bool,
) -> str | SystemPromptPreset:
"""Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`.
When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset`
dict so the Claude Code default prompt becomes a cacheable prefix shared
across all users; our custom *system_prompt* is appended after it.
When disabled (or if the SDK is too old to support ``SystemPromptPreset``),
the raw *system_prompt* string is returned unchanged.
An empty *system_prompt* is accepted: the preset dict will have
``append: ""`` which the SDK treats as no custom suffix.
"""
if cross_user_cache:
logger.debug("Using SystemPromptPreset for cross-user prompt cache")
return SystemPromptPreset(
type="preset",
preset="claude_code",
append=system_prompt,
exclude_dynamic_sections=True,
)
logger.debug("Cross-user prompt cache disabled, using raw string")
return system_prompt
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -801,6 +924,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
async def _compress_messages(
messages: list[ChatMessage],
target_tokens: int | None = None,
) -> tuple[list[ChatMessage], bool]:
"""Compress a list of messages if they exceed the token threshold.
@@ -809,6 +933,10 @@ async def _compress_messages(
`_compress_messages` and `compact_transcript` share this helper so
client acquisition and error handling are consistent.
``target_tokens`` sets a hard ceiling for the compressed output so
callers can enforce a tighter budget on retries. When ``None``,
``compress_context`` uses the model-aware default.
See also:
`_run_compression` — shared compression with timeout guards.
`compact_transcript` — compresses JSONL transcript entries.
@@ -832,7 +960,9 @@ async def _compress_messages(
messages_dict.append(msg_dict)
try:
result = await _run_compression(messages_dict, config.model, "[SDK]")
result = await _run_compression(
messages_dict, config.model, "[SDK]", target_tokens=target_tokens
)
except Exception as exc:
# Guard against timeouts or unexpected errors in compression —
# return the original messages so the caller can proceed without
@@ -961,44 +1091,139 @@ async def _build_query_message(
use_resume: bool,
transcript_msg_count: int,
session_id: str,
target_tokens: int | None = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
When ``use_resume=True``, the CLI has the full session via ``--resume``;
only a gap-fill prefix is injected when the transcript is stale.
When ``use_resume=False``, the CLI starts a fresh session with no prior
context, so the full prior session is always compressed and injected via
``_format_conversation_context``. ``compress_context`` handles size
reduction internally (LLM summarize → content truncate → middle-out delete
→ first/last trim). ``target_tokens`` decreases on each retry to force
progressively more aggressive compression when the first attempt exceeds
context limits.
Returns:
Tuple of (query_message, was_compacted).
"""
msg_count = len(session.messages)
prior = session.messages[:-1] # all turns except the current user message
logger.info(
"[SDK] [%s] Context path: use_resume=%s, transcript_msg_count=%d,"
" db_msg_count=%d, target_tokens=%s",
session_id[:8],
use_resume,
transcript_msg_count,
msg_count,
target_tokens,
)
if use_resume and transcript_msg_count > 0:
if transcript_msg_count < msg_count - 1:
gap = session.messages[transcript_msg_count:-1]
compressed, was_compressed = await _compress_messages(gap)
# Sanity-check the watermark: the last covered position should be
# an assistant turn. A user-role message here means the count is
# misaligned (e.g. a message was deleted and DB positions shifted).
# Skip the gap rather than injecting wrong context — the CLI session
# loaded via --resume still has good history.
if prior[transcript_msg_count - 1].role != "assistant":
logger.warning(
"[SDK] [%s] Watermark misaligned: prior[%d].role=%r"
" (expected 'assistant') — skipping gap to avoid"
" injecting wrong context (transcript=%d, db=%d)",
session_id[:8],
transcript_msg_count - 1,
prior[transcript_msg_count - 1].role,
transcript_msg_count,
msg_count,
)
return current_message, False
gap = prior[transcript_msg_count:]
compressed, was_compressed = await _compress_messages(gap, target_tokens)
gap_context = _format_conversation_context(compressed)
if gap_context:
logger.info(
"[SDK] Transcript stale: covers %d of %d messages, "
"gap=%d (compressed=%s)",
"gap=%d (compressed=%s), gap_context_bytes=%d",
transcript_msg_count,
msg_count,
len(gap),
was_compressed,
len(gap_context),
)
return (
f"{gap_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
logger.warning(
"[SDK] [%s] Transcript stale: gap produced empty context"
" (%d msgs, transcript=%d/%d) — sending message without gap prefix",
session_id[:8],
len(gap),
transcript_msg_count,
msg_count,
)
else:
logger.info(
"[SDK] [%s] --resume covers full context (%d messages)",
session_id[:8],
transcript_msg_count,
)
return current_message, False
elif not use_resume and msg_count > 1:
# No --resume: the CLI starts a fresh session with no prior context.
# Injecting only the post-transcript gap would omit the transcript-covered
# prefix entirely, so always compress the full prior session here.
# compress_context handles size reduction internally (LLM summarize →
# content truncate → middle-out delete → first/last trim).
# Final escape hatch: if the token budget is at or below the floor,
# the model context is so tight that even fully compressed history
# would risk a "prompt too long" error. Return the bare message so
# the user always gets a response rather than a hard failure.
if target_tokens is not None and target_tokens <= _BARE_MESSAGE_TOKEN_FLOOR:
logger.warning(
"[SDK] [%s] target_tokens=%d at or below floor (%d) —"
" skipping history injection to guarantee response delivery"
" (session has %d messages)",
session_id[:8],
target_tokens,
_BARE_MESSAGE_TOKEN_FLOOR,
msg_count,
)
return current_message, False
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({msg_count} messages) — no transcript for --resume"
"[SDK] [%s] No --resume for %d-message session — compressing"
" full session history (pod affinity issue or first turn after"
" restore failure); target_tokens=%s",
session_id[:8],
msg_count,
target_tokens,
)
compressed, was_compressed = await _compress_messages(session.messages[:-1])
compressed, was_compressed = await _compress_messages(prior, target_tokens)
history_context = _format_conversation_context(compressed)
if history_context:
logger.info(
"[SDK] [%s] Fallback context built: compressed=%s," " context_bytes=%d",
session_id[:8],
was_compressed,
len(history_context),
)
return (
f"{history_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
logger.warning(
"[SDK] [%s] Fallback context empty after compression"
" (%d messages) — sending message without history",
session_id[:8],
len(prior),
)
return current_message, False
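For concreteness, the fallback query assembled here looks roughly like the string below. The wrapper tag and the "Now, the user says:" trailer come straight from the code and test assertions, while the per-message layout inside the tags is an assumption about `_format_conversation_context`.
query_message_example = (
    "<conversation_history>\n"
    "user: q1\n"          # layout inside the tags is assumed
    "assistant: a1\n"
    "</conversation_history>\n"
    "\n"
    "Now, the user says:\n"
    "q3"
)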
@@ -1688,15 +1913,20 @@ async def _run_stream_attempt(
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
state.usage.cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
# Use `or 0` instead of a default in .get() because
# OpenRouter may include the key with a null value (e.g.
# {"cache_read_input_tokens": null}) for models that don't
# yet report cache tokens, making .get("key", 0) return
# None rather than the fallback 0.
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
state.usage.cache_read_tokens += (
sdk_msg.usage.get("cache_read_input_tokens") or 0
)
state.usage.cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
state.usage.cache_creation_tokens += (
sdk_msg.usage.get("cache_creation_input_tokens") or 0
)
state.usage.completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
state.usage.completion_tokens += (
sdk_msg.usage.get("output_tokens") or 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, "
@@ -1922,6 +2152,48 @@ async def _run_stream_attempt(
)
async def _seed_transcript(
session: ChatSession,
transcript_builder: TranscriptBuilder,
transcript_covers_prefix: bool,
transcript_msg_count: int,
log_prefix: str,
) -> tuple[str, bool, int]:
"""Seed the transcript builder from compressed DB messages.
Called when ``use_resume=False`` and no prior transcript exists in storage
so that ``upload_transcript`` saves a compact version for future turns.
This ensures the next turn can use the full-session compression path with
the benefit of an already-compressed baseline, and a restored CLI session
on the next pod gets a usable compact base even for sessions that started
on old pods.
Returns ``(transcript_content, transcript_covers_prefix, transcript_msg_count)``
updated values — unchanged if seeding is not possible.
"""
if len(session.messages) <= 1:
return "", transcript_covers_prefix, transcript_msg_count
_prior = session.messages[:-1]
_comp, _ = await _compress_messages(_prior, _SEED_TARGET_TOKENS)
if not _comp:
return "", transcript_covers_prefix, transcript_msg_count
_seeded = _session_messages_to_transcript(_comp)
if not _seeded or not validate_transcript(_seeded):
return "", transcript_covers_prefix, transcript_msg_count
transcript_builder.load_previous(_seeded, log_prefix=log_prefix)
logger.info(
"%s Seeded transcript from %d compressed DB messages"
" for next-turn upload (seed_target_tokens=%d)",
log_prefix,
len(_comp),
_SEED_TARGET_TOKENS,
)
return _seeded, True, len(_prior)
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -1931,6 +2203,7 @@ async def stream_chat_completion_sdk(
file_ids: list[str] | None = None,
permissions: "CopilotPermissions | None" = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
**_kwargs: Any,
) -> AsyncIterator[StreamBaseResponse]:
"""Stream chat completion using Claude Agent SDK.
@@ -1941,6 +2214,9 @@ async def stream_chat_completion_sdk(
saved to the SDK working directory for the Read tool.
mode: Accepted for signature compatibility with the baseline path.
The SDK path does not currently branch on this value.
model: Per-request model preference from the frontend toggle.
'advanced' → Claude Opus; 'standard' → global config default.
Takes priority over per-user LaunchDarkly targeting.
"""
_ = mode # SDK path ignores the requested mode.
@@ -2055,6 +2331,10 @@ async def stream_chat_completion_sdk(
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
graphiti_enabled = False
# Defaults ensure the finally block can always reference these safely even when
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
sdk_model: str | None = None
model_cost_multiplier: float = 1.0
# Make sure there is no more code between the lock acquisition and try-block.
try:
@@ -2145,7 +2425,7 @@ async def stream_chat_completion_sdk(
# Warm context: pre-load relevant facts from Graphiti on first turn
if graphiti_enabled and user_id and len(session.messages) <= 1:
from backend.copilot.graphiti.context import fetch_warm_context
from ..graphiti.context import fetch_warm_context
warm_ctx = await fetch_warm_context(user_id, message or "")
if warm_ctx:
@@ -2193,9 +2473,20 @@ async def stream_chat_completion_sdk(
# Builder loaded but CLI native session not available.
# --resume will not be used this turn; upload after turn
# will seed the native session for the next turn.
#
# Still record transcript_msg_count so _build_query_message
# can use the transcript-aware gap path (inject only new
# messages since the transcript end) instead of compressing
# the full DB history. This avoids prompt-too-long on
# large sessions where the CLI session is temporarily
# unavailable (e.g. mixed-version rolling deployment).
transcript_msg_count = dl.message_count
logger.info(
"%s CLI session not restored — running without --resume this turn",
"%s CLI session not restored — running without"
" --resume this turn (transcript_msg_count=%d for"
" gap-aware fallback)",
log_prefix,
transcript_msg_count,
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
@@ -2255,7 +2546,10 @@ async def stream_chat_completion_sdk(
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
sdk_model = _resolve_sdk_model()
# Resolve model and cost multiplier (request tier → config default).
sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier(
model, session_id
)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
@@ -2290,8 +2584,19 @@ async def stream_chat_completion_sdk(
sid,
)
# Use SystemPromptPreset for cross-user prompt caching.
# WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when
# excludeDynamicSections=True is in the initialize request AND
# --resume is active. Disable the preset on resumed turns.
# Turn 1 still gets the preset (no --resume).
_cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume
system_prompt_value = _build_system_prompt_value(
system_prompt,
cross_user_cache=_cross_user,
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"system_prompt": system_prompt_value,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": allowed,
"disallowed_tools": disallowed,
@@ -2330,13 +2635,19 @@ async def stream_chat_completion_sdk(
# --session-id here. CLI >=2.1.97 rejects the combination of
# --session-id + --resume unless --fork-session is also given.
sdk_options_kwargs["resume"] = resume_file
elif not has_history:
# T1 only: write CLI native session to a predictable path so
# upload_cli_session() can find it after the turn completes.
# On T2+ without --resume the T1 session file already exists at
# that path; passing --session-id again would fail with
# "Session ID already in use". The upload guard also skips T2+
# no-resume turns, so --session-id provides no benefit there.
else:
# Set session_id whenever NOT resuming so the CLI writes the
# native session file to a predictable path for
# upload_cli_session() after the turn. This covers:
# • T1 fresh: no prior history, first SDK turn.
# • Mode-switch T1: has_history=True (prior baseline turns in
# DB) but no CLI session file was ever uploaded — the CLI has
# never been invoked with this session_id before.
# • T2+ without --resume (restore failed): no session file was
# restored to local storage (restore_cli_session returned
# False), so no conflict with an existing file.
# When --resume is active the session_id is already implied by
# the resume file; passing it again would be rejected by the CLI.
sdk_options_kwargs["session_id"] = session_id
# Optional explicit Claude Code CLI binary path (decouples the
# bundled SDK version from the CLI version we run — needed because
@@ -2420,6 +2731,22 @@ async def stream_chat_completion_sdk(
if attachments.hint:
query_message = f"{query_message}\n\n{attachments.hint}"
# When running without --resume and no prior transcript in storage,
# seed the transcript builder from compressed DB messages so that
# upload_transcript saves a compact version for future turns.
if not use_resume and not transcript_content and not skip_transcript_upload:
(
transcript_content,
transcript_covers_prefix,
transcript_msg_count,
) = await _seed_transcript(
session,
transcript_builder,
transcript_covers_prefix,
transcript_msg_count,
log_prefix,
)
tried_compaction = False
# Build the per-request context carrier (shared across attempts).
@@ -2502,12 +2829,14 @@ async def stream_chat_completion_sdk(
session_id,
sdk_cwd,
log_prefix,
attempt=attempt,
)
state.transcript_builder = ctx.builder
state.use_resume = ctx.use_resume
state.resume_file = ctx.resume_file
tried_compaction = ctx.tried_compaction
state.transcript_msg_count = 0
state.target_tokens = ctx.target_tokens
if ctx.transcript_lost:
skip_transcript_upload = True
@@ -2516,18 +2845,30 @@ async def stream_chat_completion_sdk(
if ctx.use_resume and ctx.resume_file:
sdk_options_kwargs_retry["resume"] = ctx.resume_file
sdk_options_kwargs_retry.pop("session_id", None)
elif not has_history:
# T1 retry: keep session_id so the CLI writes to the
# predictable path for upload_cli_session().
elif "session_id" in sdk_options_kwargs:
# Initial invocation used session_id (T1 or mode-switch
# T1): keep it so the CLI writes the session file to the
# predictable path for upload_cli_session(). Storage is
# ephemeral per invocation, so no "Session ID already in
# use" conflict occurs — no prior file was restored.
sdk_options_kwargs_retry.pop("resume", None)
sdk_options_kwargs_retry["session_id"] = session_id
else:
# T2+ retry without --resume: do not pass --session-id.
# The T1 session file already exists at that path; re-using
# the same ID would fail with "Session ID already in use".
# The upload guard skips T2+ no-resume turns anyway.
sdk_options_kwargs_retry.pop("resume", None)
sdk_options_kwargs_retry.pop("session_id", None)
# Recompute system_prompt for retry — ctx.use_resume may have
# changed (context reduction enabled --resume). CLI 2.1.97
# crashes when excludeDynamicSections=True is combined with
# --resume, so disable the cross-user preset on resumed turns.
_cross_user_retry = (
config.claude_agent_cross_user_prompt_cache and not ctx.use_resume
)
sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value(
system_prompt, cross_user_cache=_cross_user_retry
)
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
state.query_message, state.was_compacted = await _build_query_message(
current_message,
@@ -2535,6 +2876,7 @@ async def stream_chat_completion_sdk(
state.use_resume,
state.transcript_msg_count,
session_id,
target_tokens=state.target_tokens,
)
if attachments.hint:
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
@@ -2901,8 +3243,9 @@ async def stream_chat_completion_sdk(
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=config.model,
model=sdk_model or config.model,
provider="anthropic",
model_cost_multiplier=model_cost_multiplier,
)
# --- Persist session messages ---
@@ -2939,7 +3282,7 @@ async def stream_chat_completion_sdk(
# --- Graphiti: ingest conversation turn for temporal memory ---
if graphiti_enabled and user_id and message and is_user_message:
from backend.copilot.graphiti.ingest import enqueue_conversation_turn
from ..graphiti.ingest import enqueue_conversation_turn
_ingest_task = asyncio.create_task(
enqueue_conversation_turn(user_id, session_id, message)
@@ -3020,6 +3363,21 @@ async def stream_chat_completion_sdk(
# the shielded inner coroutine continues running to completion so the
# upload is not lost. This is intentional and matches the pattern
# used for upload_transcript immediately above.
#
# NOTE: upload is attempted regardless of state.use_resume — even when
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# upload_cli_session silently skips when the file is absent, so this is
# always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
# when our custom JSONL transcript is dropped (transcript_lost=True on
# reduced-context retries) but the CLI's native session file is written
# independently. Blocking CLI upload on transcript_lost would prevent
# T1 prompt-too-long retries from uploading their valid session file,
# breaking --resume on the next pod. The ended_with_stream_error gate
# above already covers actual turn failures.
if (
config.claude_agent_use_resume
and user_id
@@ -3027,9 +3385,15 @@ async def stream_chat_completion_sdk(
and session is not None
and state is not None
and not ended_with_stream_error
and not skip_transcript_upload
and (not has_history or state.use_resume)
):
logger.info(
"%s Attempting CLI session upload"
" (use_resume=%s, has_history=%s, skip_transcript=%s)",
log_prefix,
state.use_resume,
has_history,
skip_transcript_upload,
)
try:
await asyncio.shield(
upload_cli_session(

View File

@@ -15,11 +15,14 @@ from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
from .conftest import build_test_transcript as _build_transcript
from .service import (
_RETRY_TARGET_TOKENS,
ReducedContext,
_is_prompt_too_long,
_is_tool_only_message,
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_TokenUsage,
)
# ---------------------------------------------------------------------------
@@ -207,6 +210,24 @@ class TestReduceContext:
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_drop_returns_target_tokens_attempt_1(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=1)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[0]
@pytest.mark.asyncio
async def test_drop_returns_target_tokens_attempt_2(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=2)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[1]
@pytest.mark.asyncio
async def test_drop_clamps_attempt_beyond_limits(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=99)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[-1]
# ---------------------------------------------------------------------------
# _iter_sdk_messages
@@ -331,3 +352,264 @@ class TestIsParallelContinuation:
msg = MagicMock(spec=AssistantMessage)
msg.content = [self._make_tool_block()]
assert _is_tool_only_message(msg) is True
# ---------------------------------------------------------------------------
# _normalize_model_name — used by per-request model override
# ---------------------------------------------------------------------------
class TestNormalizeModelName:
"""Unit tests for the model-name normalisation helper.
The per-request model toggle calls _normalize_model_name with either
``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
'standard'). These tests verify the OpenRouter/provider-prefix stripping
that keeps the value compatible with the Claude CLI.
"""
def test_strips_anthropic_prefix(self):
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_strips_openai_prefix(self):
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
def test_strips_google_prefix(self):
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
def test_already_normalized_unchanged(self):
assert (
_normalize_model_name("claude-sonnet-4-20250514")
== "claude-sonnet-4-20250514"
)
def test_empty_string_unchanged(self):
assert _normalize_model_name("") == ""
def test_opus_model_roundtrip(self):
"""The exact string used for the 'opus' toggle strips correctly."""
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_sonnet_openrouter_model(self):
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4"
# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
# ---------------------------------------------------------------------------
class TestTokenUsageNullSafety:
"""Verify that ResultMessage.usage dicts with null-valued cache fields
(as emitted by OpenRouter for the initial streaming event before real
token counts are available) do not crash the accumulator.
Before the fix, dict.get("cache_read_input_tokens", 0) returned None
when the key existed with a null value, causing 'int += None' TypeError.
"""
def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
"""Null-safe accumulation: ``or 0`` treats missing/None as zero.
Uses ``usage.get("key") or 0`` rather than ``usage.get("key", 0)``
because the latter returns ``None`` when the key exists with a null
value, which would raise ``TypeError`` on ``int += None``. This is
the intentional pattern that fixes the OpenRouter initial-stream-event
bug described in the class docstring.
"""
acc.prompt_tokens += usage.get("input_tokens") or 0
acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0
acc.completion_tokens += usage.get("output_tokens") or 0
def test_null_cache_tokens_do_not_crash(self):
"""OpenRouter initial event: cache keys present with null value."""
usage = {
"input_tokens": 0,
"output_tokens": 0,
"cache_read_input_tokens": None,
"cache_creation_input_tokens": None,
}
acc = _TokenUsage()
self._apply_usage(usage, acc) # must not raise TypeError
assert acc.prompt_tokens == 0
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
assert acc.completion_tokens == 0
def test_real_cache_tokens_are_accumulated(self):
"""OpenRouter final event: real cache token counts are captured."""
usage = {
"input_tokens": 10,
"output_tokens": 349,
"cache_read_input_tokens": 16600,
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(usage, acc)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
assert acc.completion_tokens == 349
def test_absent_cache_keys_default_to_zero(self):
"""Minimal usage dict without cache keys defaults correctly."""
usage = {"input_tokens": 5, "output_tokens": 20}
acc = _TokenUsage()
self._apply_usage(usage, acc)
assert acc.prompt_tokens == 5
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
assert acc.completion_tokens == 20
def test_multi_turn_accumulation(self):
"""Null event followed by real event: only real tokens counted."""
null_event = {
"input_tokens": 0,
"output_tokens": 0,
"cache_read_input_tokens": None,
"cache_creation_input_tokens": None,
}
real_event = {
"input_tokens": 10,
"output_tokens": 349,
"cache_read_input_tokens": 16600,
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(null_event, acc)
self._apply_usage(real_event, acc)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
assert acc.completion_tokens == 349
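For reference, the exact pitfall the ``or 0`` pattern avoids, as a two-line check:

usage = {"cache_read_input_tokens": None}
usage.get("cache_read_input_tokens", 0)    # -> None: key exists, default unused
usage.get("cache_read_input_tokens") or 0  # -> 0: the fixed pattern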
# ---------------------------------------------------------------------------
# session_id / resume selection logic
# ---------------------------------------------------------------------------
def _build_sdk_options(
use_resume: bool,
resume_file: str | None,
session_id: str,
) -> dict:
"""Mirror the session_id/resume selection in stream_chat_completion_sdk.
This helper encodes the exact branching so the unit tests stay in sync
with the production code without needing to invoke the full generator.
"""
kwargs: dict = {}
if use_resume and resume_file:
kwargs["resume"] = resume_file
else:
kwargs["session_id"] = session_id
return kwargs
def _build_retry_sdk_options(
initial_kwargs: dict,
ctx_use_resume: bool,
ctx_resume_file: str | None,
session_id: str,
) -> dict:
"""Mirror the retry branch in stream_chat_completion_sdk."""
retry: dict = dict(initial_kwargs)
if ctx_use_resume and ctx_resume_file:
retry["resume"] = ctx_resume_file
retry.pop("session_id", None)
elif "session_id" in initial_kwargs:
retry.pop("resume", None)
retry["session_id"] = session_id
else:
retry.pop("resume", None)
retry.pop("session_id", None)
return retry
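For quick reference, the branching the two helpers (and the tests below) encode, as a summary rather than production code:

# Initial turn:
#   use_resume and resume_file   -> {"resume": resume_file}
#   otherwise                    -> {"session_id": session_id}
# Retry:
#   ctx still has a resume file  -> {"resume": ctx_resume_file}, session_id dropped
#   initial had session_id       -> {"session_id": session_id}, resume dropped
#   otherwise                    -> neither key (avoids a session-ID conflict)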
class TestSdkSessionIdSelection:
"""Verify that session_id is set for all non-resume turns.
Regression test for the mode-switch T1 bug: when a user switches from
baseline mode (fast) to SDK mode (extended_thinking) mid-session, the
first SDK turn has has_history=True but no CLI session file. The old
code gated session_id on ``not has_history``, so mode-switch T1 never
got a session_id — the CLI used a random ID that couldn't be found on
the next turn, causing --resume to fail for the whole session.
"""
SESSION_ID = "sess-abc123"
def test_t1_fresh_sets_session_id(self):
"""T1 of a fresh session always gets session_id."""
opts = _build_sdk_options(
use_resume=False,
resume_file=None,
session_id=self.SESSION_ID,
)
assert opts.get("session_id") == self.SESSION_ID
assert "resume" not in opts
def test_mode_switch_t1_sets_session_id(self):
"""Mode-switch T1 (has_history=True, no CLI session) gets session_id.
Before the fix, the ``elif not has_history`` guard prevented this
case from setting session_id, causing all subsequent turns to run
without --resume.
"""
# Mode-switch T1: use_resume=False (no prior CLI session) and
# has_history=True (prior baseline turns in DB). The old code
# (``elif not has_history``) silently skipped this case.
opts = _build_sdk_options(
use_resume=False,
resume_file=None,
session_id=self.SESSION_ID,
)
assert opts.get("session_id") == self.SESSION_ID
assert "resume" not in opts
def test_t2_with_resume_uses_resume(self):
"""T2+ with a restored CLI session uses --resume, not session_id."""
opts = _build_sdk_options(
use_resume=True,
resume_file=self.SESSION_ID,
session_id=self.SESSION_ID,
)
assert opts.get("resume") == self.SESSION_ID
assert "session_id" not in opts
def test_t2_without_resume_sets_session_id(self):
"""T2+ when restore failed still gets session_id (no prior file on disk)."""
opts = _build_sdk_options(
use_resume=False,
resume_file=None,
session_id=self.SESSION_ID,
)
assert opts.get("session_id") == self.SESSION_ID
assert "resume" not in opts
def test_retry_keeps_session_id_for_t1(self):
"""Retry for T1 (or mode-switch T1) preserves session_id."""
initial = _build_sdk_options(False, None, self.SESSION_ID)
retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
assert retry.get("session_id") == self.SESSION_ID
assert "resume" not in retry
def test_retry_removes_session_id_for_t2_plus(self):
"""Retry for T2+ (initial used --resume) removes session_id to avoid conflict."""
initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
# T2+ retry where context reduction dropped --resume
retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
assert "session_id" not in retry
assert "resume" not in retry
def test_retry_t2_with_resume_sets_resume(self):
"""Retry that still uses --resume keeps --resume and drops session_id."""
initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
retry = _build_retry_sdk_options(
initial, True, self.SESSION_ID, self.SESSION_ID
)
assert retry.get("resume") == self.SESSION_ID
assert "session_id" not in retry

View File

@@ -8,7 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot import config as cfg_mod
from .service import (
_build_system_prompt_value,
_is_sdk_disconnect_error,
_normalize_model_name,
_prepare_file_attachments,
@@ -397,6 +400,7 @@ _CONFIG_ENV_VARS = (
"OPENAI_BASE_URL",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
)
@@ -656,3 +660,62 @@ class TestSafeCloseSdkClient:
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
with pytest.raises(ValueError, match="invalid argument"):
await _safe_close_sdk_client(client, "[test]")
# ---------------------------------------------------------------------------
# SystemPromptPreset — cross-user prompt caching
# ---------------------------------------------------------------------------
class TestSystemPromptPreset:
"""Tests for _build_system_prompt_value — cross-user prompt caching."""
def test_preset_dict_structure_when_enabled(self):
"""When cross_user_cache is True, returns a _SystemPromptPreset dict."""
custom_prompt = "You are a helpful assistant."
result = _build_system_prompt_value(custom_prompt, cross_user_cache=True)
assert isinstance(result, dict)
assert result["type"] == "preset"
assert result["preset"] == "claude_code"
assert result["append"] == custom_prompt
assert result["exclude_dynamic_sections"] is True
def test_raw_string_when_disabled(self):
"""When cross_user_cache is False, returns the raw string."""
custom_prompt = "You are a helpful assistant."
result = _build_system_prompt_value(custom_prompt, cross_user_cache=False)
assert isinstance(result, str)
assert result == custom_prompt
def test_empty_string_with_cache_enabled(self):
"""Empty system_prompt with cross_user_cache=True produces append=''."""
result = _build_system_prompt_value("", cross_user_cache=True)
assert isinstance(result, dict)
assert result["type"] == "preset"
assert result["preset"] == "claude_code"
assert result["append"] == ""
assert result["exclude_dynamic_sections"] is True
def test_default_config_is_enabled(self, _clean_config_env):
"""The default value for claude_agent_cross_user_prompt_cache is True."""
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.claude_agent_cross_user_prompt_cache is True
def test_env_var_disables_cache(self, _clean_config_env, monkeypatch):
"""CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false disables caching."""
monkeypatch.setenv("CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE", "false")
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.claude_agent_cross_user_prompt_cache is False
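A sketch consistent with the assertions above — an assumption, since the helper body is not part of this diff and may use a typed _SystemPromptPreset instead of a plain dict:

def _build_system_prompt_value_sketch(prompt: str, cross_user_cache: bool):
    if not cross_user_cache:
        return prompt  # raw string path
    return {
        "type": "preset",
        "preset": "claude_code",
        "append": prompt,
        "exclude_dynamic_sections": True,
    }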

View File

@@ -960,7 +960,7 @@ class TestRunCompression:
)
call_count = [0]
async def _compress_side_effect(*, messages, model, client):
async def _compress_side_effect(*, messages, model, client, target_tokens=None):
call_count[0] += 1
if client is not None:
# Simulate a hang that exceeds the timeout

View File

@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
)
logger.debug(f"Successfully unsubscribed from session {session_id}")
async def disconnect_all_listeners(session_id: str) -> int:
"""Cancel every active listener task for *session_id*.
Called when the frontend switches away from a session and wants the
backend to release resources immediately rather than waiting for the
XREAD timeout.
Scope / limitations (best-effort optimisation, not a correctness primitive):
- Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
lands on a different worker than the one serving the SSE, no listener
is cancelled here — the SSE worker still releases on its XREAD timeout.
- Session-scoped (not subscriber-scoped): cancels every active listener
for the session on this pod. In the rare case a single user opens two
SSE connections to the same session on the same pod (e.g. two tabs),
both would be torn down. Cross-pod, subscriber-scoped cancellation
would require a Redis pub/sub fan-out with per-listener tokens; that
is not implemented here because the XREAD timeout already bounds the
worst case.
Returns the number of listener tasks that were cancelled.
"""
to_cancel: list[tuple[int, asyncio.Task]] = [
(qid, task)
for qid, (sid, task) in list(_listener_sessions.items())
if sid == session_id and not task.done()
]
for qid, task in to_cancel:
_listener_sessions.pop(qid, None)
task.cancel()
cancelled = 0
for _qid, task in to_cancel:
try:
await asyncio.wait_for(task, timeout=5.0)
except asyncio.CancelledError:
cancelled += 1
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"Error cancelling listener for session {session_id}: {e}")
if cancelled:
logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}")
return cancelled
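A minimal usage sketch, assuming a FastAPI route — the path, router wiring, and response shape here are hypothetical; only disconnect_all_listeners itself comes from this diff:

from fastapi import APIRouter

from backend.copilot import stream_registry

router = APIRouter()

@router.delete("/sessions/{session_id}/listeners")  # hypothetical path
async def release_session_listeners(session_id: str) -> dict[str, int]:
    # Best-effort, pod-local cleanup; see the scope notes in the docstring.
    cancelled = await stream_registry.disconnect_all_listeners(session_id)
    return {"cancelled": cancelled}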

View File

@@ -0,0 +1,110 @@
"""Tests for disconnect_all_listeners in stream_registry."""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot import stream_registry
@pytest.fixture(autouse=True)
def _clear_listener_sessions():
stream_registry._listener_sessions.clear()
yield
stream_registry._listener_sessions.clear()
async def _sleep_forever():
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
raise
@pytest.mark.asyncio
async def test_disconnect_all_listeners_cancels_matching_session():
task_a = asyncio.create_task(_sleep_forever())
task_b = asyncio.create_task(_sleep_forever())
task_other = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task_a)
stream_registry._listener_sessions[2] = ("sess-1", task_b)
stream_registry._listener_sessions[3] = ("sess-other", task_other)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 2
assert task_a.cancelled()
assert task_b.cancelled()
assert not task_other.done()
# Matching entries are removed, non-matching entries remain.
assert 1 not in stream_registry._listener_sessions
assert 2 not in stream_registry._listener_sessions
assert 3 in stream_registry._listener_sessions
finally:
task_other.cancel()
try:
await task_other
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_no_match_returns_zero():
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-other", task)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-missing")
assert cancelled == 0
assert not task.done()
assert 1 in stream_registry._listener_sessions
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_skips_already_done_tasks():
async def _noop():
return None
done_task = asyncio.create_task(_noop())
await done_task
stream_registry._listener_sessions[1] = ("sess-1", done_task)
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
# Done tasks are filtered out before cancellation.
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_empty_registry():
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_timeout_not_counted():
"""Tasks that don't respond to cancellation (timeout) are not counted."""
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task)
with patch.object(
asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError)
):
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@@ -96,6 +96,7 @@ async def persist_and_record_usage(
cost_usd: float | str | None = None,
model: str | None = None,
provider: str = "open_router",
model_cost_multiplier: float = 1.0,
) -> int:
"""Persist token usage to session and record for rate limiting.
@@ -109,6 +110,9 @@ async def persist_and_record_usage(
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
provider: Cost provider name (e.g. "anthropic", "open_router").
model_cost_multiplier: Relative model cost factor for rate limiting
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
more expensive models deplete the rate limit proportionally faster.
Returns:
The computed total_tokens (prompt + completion; cache excluded).
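Illustrative arithmetic for the multiplier described above, assuming the limiter scales the recorded count linearly (the exact point of application inside the limiter is not shown here):

total_tokens = 1_000 + 500               # prompt + completion, the returned value
rate_limit_charge = total_tokens * 5.0   # an Opus turn depletes 7_500 tokens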
@@ -163,6 +167,7 @@ async def persist_and_record_usage(
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
model_cost_multiplier=model_cost_multiplier,
)
except Exception as usage_err:
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)

View File

@@ -230,6 +230,7 @@ class TestRateLimitRecording:
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
model_cost_multiplier=1.0,
)
@pytest.mark.asyncio

View File

@@ -74,6 +74,15 @@ class FindBlockTool(BaseTool):
"description": "Include full input/output schemas (for agent JSON generation).",
"default": False,
},
"for_agent_generation": {
"type": "boolean",
"description": (
"Set to true when searching for blocks to use inside an agent graph "
"(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). "
"Bypasses the CoPilot-only filter so graph-only blocks are visible."
),
"default": False,
},
},
"required": ["query"],
}
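The schema above admits calls like the following sketch; the user id and query are illustrative, and the parameter names come from _execute below:

tool = FindBlockTool()
response = await tool._execute(
    user_id="user-123",          # hypothetical user
    session=session,             # an existing ChatSession
    query="agent output",
    include_schemas=True,        # include I/O schemas for JSON generation
    for_agent_generation=True,   # surface graph-only blocks (INPUT, OUTPUT, ...)
)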
@@ -88,6 +97,7 @@ class FindBlockTool(BaseTool):
session: ChatSession,
query: str = "",
include_schemas: bool = False,
for_agent_generation: bool = False,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
@@ -97,6 +107,8 @@ class FindBlockTool(BaseTool):
session: Chat session
query: Search query
include_schemas: Whether to include block schemas in results
for_agent_generation: When True, bypasses the CoPilot exclusion filter
so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible.
Returns:
BlockListResponse: List of matching blocks
@@ -123,34 +135,36 @@ class FindBlockTool(BaseTool):
suggestions=["Search for an alternative block by name"],
session_id=session_id,
)
if (
is_excluded = (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
if block.block_type == BlockType.MCP_TOOL:
)
if is_excluded:
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
# exposed when building an agent graph so the LLM can inspect
# their schemas and wire them as nodes. In CoPilot direct use
# they are not executable — guide the LLM to the right tool.
if not for_agent_generation:
if block.block_type == BlockType.MCP_TOOL:
message = (
f"Block '{block.name}' (ID: {block.id}) cannot be "
"run directly in CoPilot. Use run_mcp_tool for "
"interactive MCP execution, or call find_block with "
"for_agent_generation=true to embed it in an agent graph."
)
else:
message = (
f"Block '{block.name}' (ID: {block.id}) is not available "
"in CoPilot. It can only be used within agent graphs."
)
return NoResultsResponse(
message=(
f"Block '{block.name}' (ID: {block.id}) is not "
"runnable through find_block/run_block. Use "
"run_mcp_tool instead."
),
message=message,
suggestions=[
"Use run_mcp_tool to discover and run this MCP tool",
"Search for an alternative block by name",
"Use this block in an agent graph instead",
],
session_id=session_id,
)
return NoResultsResponse(
message=(
f"Block '{block.name}' (ID: {block.id}) is not available "
"in CoPilot. It can only be used within agent graphs."
),
suggestions=[
"Search for an alternative block by name",
"Use this block in an agent graph instead",
],
session_id=session_id,
)
# Check block-level permissions — hide denied blocks entirely
perms = get_current_permissions()
@@ -221,8 +235,9 @@ class FindBlockTool(BaseTool):
if not block or block.disabled:
continue
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
# skipped in CoPilot direct use but surfaced for agent graph building.
if not for_agent_generation and (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):

View File

@@ -12,7 +12,7 @@ from .find_block import (
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from .models import BlockListResponse
from .models import BlockListResponse, NoResultsResponse
_TEST_USER_ID = "test-user-find-block"
@@ -166,6 +166,194 @@ class TestFindBlockFiltering:
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_blocks_in_search(self):
"""With for_agent_generation=True, excluded block types appear in search results."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "output-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Agent Input", BlockType.INPUT)
output_block = make_mock_block(
"output-block-id", "Agent Output", BlockType.OUTPUT
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"output-block-id": output_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="agent input",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
block_ids = {b.id for b in response.blocks}
assert "input-block-id" in block_ids
assert "output-block-id" in block_ids
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self):
"""MCP_TOOL blocks appear in search results when for_agent_generation=True."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "mcp-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
standard_block = make_mock_block(
"standard-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"mcp-block-id": mcp_block,
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="mcp tool",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
assert any(b.id == "mcp-block-id" for b in response.blocks)
assert any(b.id == "standard-block-id" for b in response.blocks)
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self):
"""MCP_TOOL blocks are excluded from search in normal CoPilot mode."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "mcp-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
standard_block = make_mock_block(
"standard-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"mcp-block-id": mcp_block,
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="mcp tool",
for_agent_generation=False,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_ids_in_search(self):
"""With for_agent_generation=True, excluded block IDs appear in search results."""
session = make_session(user_id=_TEST_USER_ID)
orchestrator_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS))
search_results = [
{"content_id": orchestrator_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
orchestrator_block = make_mock_block(
orchestrator_id, "Orchestrator", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
orchestrator_id: orchestrator_block,
"normal-block-id": normal_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="orchestrator",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
block_ids = {b.id for b in response.blocks}
assert orchestrator_id in block_ids
assert "normal-block-id" in block_ids
@pytest.mark.asyncio(loop_scope="session")
async def test_response_size_average_chars_per_block(self):
"""Measure average chars per block in the serialized response."""
@@ -549,8 +737,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
@pytest.mark.asyncio(loop_scope="session")
@@ -571,8 +757,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "disabled" in response.message.lower()
@@ -592,8 +776,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "not available" in response.message.lower()
@@ -613,7 +795,74 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=orchestrator_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "not available" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation(
self,
):
"""With for_agent_generation=True, excluded block types (INPUT) are visible."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "Agent Input Block", BlockType.INPUT)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert response.count == 1
assert response.blocks[0].id == block_id
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self):
"""MCP_TOOL blocks are returned by UUID lookup when for_agent_generation=True."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].id == block_id
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self):
"""MCP_TOOL blocks are excluded by UUID lookup in normal CoPilot mode."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=False,
)
assert isinstance(response, NoResultsResponse)
assert "run_mcp_tool" in response.message

View File

@@ -1179,6 +1179,7 @@ async def _run_compression(
messages: list[dict],
model: str,
log_prefix: str,
target_tokens: int | None = None,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback.
@@ -1187,6 +1188,12 @@ async def _run_compression(
truncation-based compression which drops older messages without
summarization.
``target_tokens`` sets a hard token ceiling for the compressed output.
When ``None``, ``compress_context`` derives the limit from the model's
context window. Pass a smaller value on retries to force more aggressive
compression — the compressor will LLM-summarize, content-truncate,
middle-out delete, and first/last trim until the result fits.
A 60-second timeout prevents a hung LLM call from blocking the
retry path indefinitely. The truncation fallback also has a
30-second timeout to guard against slow tokenization on very large
@@ -1196,18 +1203,27 @@ async def _run_compression(
if client is None:
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
compress_context(
messages=messages, model=model, client=None, target_tokens=target_tokens
),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
try:
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=client),
compress_context(
messages=messages,
model=model,
client=client,
target_tokens=target_tokens,
),
timeout=_COMPACTION_TIMEOUT_SECONDS,
)
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
compress_context(
messages=messages, model=model, client=None, target_tokens=target_tokens
),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
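A hypothetical retry-path call showing the new knob; openai_client and the surrounding retry loop are assumptions, while the signature and _RETRY_TARGET_TOKENS appear elsewhere in this diff:

result = await _run_compression(
    client=openai_client,      # AsyncOpenAI-style client, or None for truncation
    messages=messages,
    model="claude-sonnet-4",   # illustrative
    log_prefix="[SDK]",
    target_tokens=_RETRY_TARGET_TOKENS[1],  # attempt 2: tighter ceiling
)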

View File

@@ -349,7 +349,7 @@ class UserCreditBase(ABC):
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.manager import (
from backend.executor.billing import (
clear_insufficient_funds_notifications,
)
@@ -554,7 +554,7 @@ class UserCreditBase(ABC):
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid circular dependency with executor.manager
from backend.executor.manager import (
from backend.executor.billing import (
clear_insufficient_funds_notifications,
)

View File

@@ -852,6 +852,7 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
cache_read_token_count: int = 0
cache_creation_token_count: int = 0
cost: int = 0
extra_cost: int = 0
extra_steps: int = 0
provider_cost: float | None = None

View File

@@ -8,6 +8,7 @@ from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
@@ -139,7 +140,10 @@ class UserCostSummary(BaseModel):
total_cost_microdollars: int
total_input_tokens: int
total_output_tokens: int
total_cache_read_tokens: int = 0
total_cache_creation_tokens: int = 0
request_count: int
cost_bearing_request_count: int = 0
class CostLogRow(BaseModel):
@@ -161,12 +165,27 @@ class CostLogRow(BaseModel):
cache_creation_tokens: int | None = None
class CostBucket(BaseModel):
bucket: str
count: int
class PlatformCostDashboard(BaseModel):
by_provider: list[ProviderCostSummary]
by_user: list[UserCostSummary]
total_cost_microdollars: int
total_requests: int
total_users: int
total_input_tokens: int = 0
total_output_tokens: int = 0
avg_input_tokens_per_request: float = 0.0
avg_output_tokens_per_request: float = 0.0
avg_cost_microdollars_per_request: float = 0.0
cost_p50_microdollars: float = 0.0
cost_p75_microdollars: float = 0.0
cost_p95_microdollars: float = 0.0
cost_p99_microdollars: float = 0.0
cost_buckets: list[CostBucket] = []
def _si(row: dict, field: str) -> int:
@@ -196,6 +215,7 @@ def _build_prisma_where(
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
graph_exec_id: str | None = None,
) -> PlatformCostLogWhereInput:
"""Build a Prisma WhereInput for PlatformCostLog filters."""
where: PlatformCostLogWhereInput = {}
@@ -223,9 +243,78 @@ def _build_prisma_where(
if tracking_type:
where["trackingType"] = tracking_type
if graph_exec_id:
where["graphExecId"] = graph_exec_id
return where
def _build_raw_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
graph_exec_id: str | None = None,
) -> tuple[str, list]:
"""Build a parameterised WHERE clause for raw SQL queries.
Mirrors the filter logic of ``_build_prisma_where`` so there is a single
source of truth for which columns are filtered and how. The first clause
always restricts to ``cost_usd`` tracking type unless *tracking_type* is
explicitly provided by the caller.
"""
params: list = []
clauses: list[str] = []
idx = 1
# Always filter by tracking type — defaults to cost_usd for percentile /
# bucket queries that only make sense on cost-denominated rows.
tt = tracking_type if tracking_type is not None else "cost_usd"
clauses.append(f'"trackingType" = ${idx}')
params.append(tt)
idx += 1
if start is not None:
clauses.append(f'"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end is not None:
clauses.append(f'"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider is not None:
clauses.append(f'"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id is not None:
clauses.append(f'"userId" = ${idx}')
params.append(user_id)
idx += 1
if model is not None:
clauses.append(f'"model" = ${idx}')
params.append(model)
idx += 1
if block_name is not None:
clauses.append(f'LOWER("blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
if graph_exec_id is not None:
clauses.append(f'"graphExecId" = ${idx}')
params.append(graph_exec_id)
idx += 1
return (" AND ".join(clauses), params)
@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
start: datetime | None = None,
@@ -235,6 +324,7 @@ async def get_platform_cost_dashboard(
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
graph_exec_id: str | None = None,
) -> PlatformCostDashboard:
"""Aggregate platform cost logs for the admin dashboard.
@@ -251,7 +341,22 @@ async def get_platform_cost_dashboard(
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
)
# For per-user tracking-type breakdown we intentionally omit the
# tracking_type filter so cost_usd and tokens rows are always present.
# This ensures cost_bearing_request_count is correct even when the caller
# is filtering the main view by a different tracking_type.
where_no_tracking_type = _build_prisma_where(
start,
end,
provider,
user_id,
model,
block_name,
tracking_type=None,
graph_exec_id=graph_exec_id,
)
sum_fields = {
@@ -264,39 +369,159 @@ async def get_platform_cost_dashboard(
"trackingAmount": True,
}
# Run all four aggregation queries in parallel.
by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
await asyncio.gather(
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
# Build parameterised WHERE clause for the raw SQL percentile/bucket
# queries. Uses _build_raw_where so filter logic is shared with
# _build_prisma_where and only maintained in one place.
# Always force tracking_type=None here so _build_raw_where defaults to
# "cost_usd" — percentile and histogram queries only make sense on
# cost-denominated rows, regardless of what the caller is filtering.
raw_where, raw_params = _build_raw_where(
start,
end,
provider,
user_id,
model,
block_name,
tracking_type=None,
graph_exec_id=graph_exec_id,
)
# Queries that always run regardless of tracking_type filter.
common_queries = [
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
PrismaLog.prisma().group_by(
by=["provider", "trackingType", "model"],
where=where,
sum=sum_fields,
count=True,
),
# userId aggregation — emails fetched separately below.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
sum=sum_fields,
count=True,
),
# Per-user cost-bearing request count: group by (userId, trackingType)
# so we can compute the correct denominator for per-user avg cost.
# Uses where_no_tracking_type so cost_usd rows are always included
# even when the caller filters the main view by a different tracking_type.
PrismaLog.prisma().group_by(
by=["userId", "trackingType"],
where=where_no_tracking_type,
count=True,
),
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
count=True,
),
# Total aggregate (filtered): group by (provider, trackingType) so we can
# compute cost-bearing and token-bearing denominators for avg stats.
PrismaLog.prisma().group_by(
by=["provider", "trackingType"],
where=where,
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
),
# Percentile distribution of cost per request (respects all filters).
query_raw_with_schema(
"SELECT"
" percentile_cont(0.5) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p50,'
" percentile_cont(0.75) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p75,'
" percentile_cont(0.95) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p95,'
" percentile_cont(0.99) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p99'
' FROM {schema_prefix}"PlatformCostLog"'
f" WHERE {raw_where}",
*raw_params,
),
# Histogram buckets for cost distribution (respects all filters).
# NULL costMicrodollars is excluded explicitly to prevent such rows
# from falling through all WHEN clauses into the ELSE '$10+' bucket.
query_raw_with_schema(
"SELECT"
" CASE"
' WHEN "costMicrodollars" < 500000'
" THEN '$0-0.50'"
' WHEN "costMicrodollars" < 1000000'
" THEN '$0.50-1'"
' WHEN "costMicrodollars" < 2000000'
" THEN '$1-2'"
' WHEN "costMicrodollars" < 5000000'
" THEN '$2-5'"
' WHEN "costMicrodollars" < 10000000'
" THEN '$5-10'"
" ELSE '$10+'"
" END as bucket,"
" COUNT(*) as count"
' FROM {schema_prefix}"PlatformCostLog"'
f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
" GROUP BY bucket"
' ORDER BY MIN("costMicrodollars")',
*raw_params,
),
]
# Only run the unfiltered aggregate query when tracking_type is set;
# when tracking_type is None, the filtered query already contains all
# tracking types and reusing it avoids a redundant full aggregation.
if tracking_type is not None:
common_queries.append(
# Total aggregate (no tracking_type filter): used to compute
# cost_bearing_requests and token_bearing_requests denominators so
# global avg stats remain meaningful when the caller filters the
# main view by a specific tracking_type (e.g. 'tokens').
PrismaLog.prisma().group_by(
by=["provider", "trackingType", "model"],
where=where,
sum=sum_fields,
by=["provider", "trackingType"],
where=where_no_tracking_type,
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
),
# userId aggregation — emails fetched separately below.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
sum=sum_fields,
count=True,
),
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
count=True,
),
# Total aggregate: group by provider (no limit) to sum across all
# matching rows. Summed in Python to get grand totals.
PrismaLog.prisma().group_by(
by=["provider"],
where=where,
sum={"costMicrodollars": True},
count=True,
),
)
)
results = await asyncio.gather(*common_queries)
# Unpack results by name for clarity.
by_provider_groups = results[0]
by_user_groups = results[1]
by_user_tracking_groups = results[2]
total_user_groups = results[3]
total_agg_groups = results[4]
percentile_rows = results[5]
bucket_rows = results[6]
# When tracking_type is None, the filtered and unfiltered queries are
# identical — reuse total_agg_groups to avoid the extra DB round-trip.
total_agg_no_tracking_type_groups = (
results[7] if tracking_type is not None else total_agg_groups
)
# Compute token grand-totals from the unfiltered aggregate so they remain
# consistent with the avg-token stats (which also use unfiltered data).
# Using by_provider_groups here would give 0 tokens when tracking_type='cost_usd'
# because cost_usd rows carry no token data, contradicting non-zero averages.
total_input_tokens = sum(
_si(r, "inputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
total_output_tokens = sum(
_si(r, "outputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
@@ -323,6 +548,61 @@ async def get_platform_cost_dashboard(
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
total_requests = sum(_ca(r) for r in total_agg_groups)
# Extract percentile values from the raw query result.
pctl = percentile_rows[0] if percentile_rows else {}
cost_p50 = float(pctl.get("p50") or 0)
cost_p75 = float(pctl.get("p75") or 0)
cost_p95 = float(pctl.get("p95") or 0)
cost_p99 = float(pctl.get("p99") or 0)
# Build cost bucket list.
cost_buckets: list[CostBucket] = [
CostBucket(bucket=r["bucket"], count=int(r["count"])) for r in bucket_rows
]
# Avg-stat numerators and denominators are derived from the unfiltered
# aggregate so they remain meaningful when the caller filters by a specific
# tracking_type. Example: filtering by 'tokens' excludes cost_usd rows from
# total_agg_groups, so avg_cost would always be 0 if we used that; using
# total_agg_no_tracking_type_groups gives the correct cost_usd total/count.
avg_cost_total = sum(
_si(r, "costMicrodollars")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "cost_usd"
)
cost_bearing_requests = sum(
_ca(r)
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "cost_usd"
)
avg_input_total = sum(
_si(r, "inputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
avg_output_total = sum(
_si(r, "outputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Token-bearing request count: only rows where trackingType == "tokens".
# Token averages must use this denominator; cost_usd rows do not carry tokens.
token_bearing_requests = sum(
_ca(r)
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Per-user cost-bearing request count: used for per-user avg cost so the
# denominator matches the numerator (cost_usd rows only, per user).
user_cost_bearing_counts: dict[str, int] = {}
for r in by_user_tracking_groups:
if r.get("trackingType") == "cost_usd" and r.get("userId"):
uid = r["userId"]
user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
r
)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
@@ -347,13 +627,38 @@ async def get_platform_cost_dashboard(
total_cost_microdollars=_si(r, "costMicrodollars"),
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
total_cache_read_tokens=_si(r, "cacheReadTokens"),
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
request_count=_ca(r),
cost_bearing_request_count=user_cost_bearing_counts.get(
r.get("userId") or "", 0
),
)
for r in by_user_groups
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
total_users=total_users,
total_input_tokens=total_input_tokens,
total_output_tokens=total_output_tokens,
avg_input_tokens_per_request=(
avg_input_total / token_bearing_requests
if token_bearing_requests > 0
else 0.0
),
avg_output_tokens_per_request=(
avg_output_total / token_bearing_requests
if token_bearing_requests > 0
else 0.0
),
avg_cost_microdollars_per_request=(
avg_cost_total / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
),
cost_p50_microdollars=cost_p50,
cost_p75_microdollars=cost_p75,
cost_p95_microdollars=cost_p95,
cost_p99_microdollars=cost_p99,
cost_buckets=cost_buckets,
)
@@ -367,12 +672,13 @@ async def get_platform_cost_logs(
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
graph_exec_id: str | None = None,
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
)
offset = (page - 1) * page_size
@@ -422,6 +728,7 @@ async def get_platform_cost_logs_for_export(
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
graph_exec_id: str | None = None,
) -> tuple[list[CostLogRow], bool]:
"""Return all matching rows up to EXPORT_MAX_ROWS.
@@ -432,7 +739,7 @@ async def get_platform_cost_logs_for_export(
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
)
rows = await PrismaLog.prisma().find_many(

View File

@@ -10,6 +10,8 @@ from backend.util.json import SafeJson
from .platform_cost import (
PlatformCostEntry,
_build_prisma_where,
_build_raw_where,
_build_where,
_mask_email,
get_platform_cost_dashboard,
@@ -156,6 +158,101 @@ class TestBuildWhere:
assert 'p."trackingType" = $3' in sql
class TestBuildPrismaWhere:
def test_both_start_and_end(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
where = _build_prisma_where(start, end, None, None)
assert where["createdAt"] == {"gte": start, "lte": end}
def test_end_only(self):
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
where = _build_prisma_where(None, end, None, None)
assert where["createdAt"] == {"lte": end}
def test_start_only(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
where = _build_prisma_where(start, None, None, None)
assert where["createdAt"] == {"gte": start}
def test_no_filters(self):
where = _build_prisma_where(None, None, None, None)
assert "createdAt" not in where
def test_provider_lowercased(self):
where = _build_prisma_where(None, None, "OpenAI", None)
assert where["provider"] == "openai"
def test_model_filter(self):
where = _build_prisma_where(None, None, None, None, model="gpt-4")
assert where["model"] == "gpt-4"
def test_block_name_case_insensitive(self):
where = _build_prisma_where(None, None, None, None, block_name="LLMBlock")
assert where["blockName"] == {"equals": "LLMBlock", "mode": "insensitive"}
def test_tracking_type(self):
where = _build_prisma_where(None, None, None, None, tracking_type="tokens")
assert where["trackingType"] == "tokens"
def test_graph_exec_id_filter(self):
where = _build_prisma_where(None, None, None, None, graph_exec_id="exec-123")
assert where["graphExecId"] == "exec-123"
def test_graph_exec_id_none_not_included(self):
where = _build_prisma_where(None, None, None, None, graph_exec_id=None)
assert "graphExecId" not in where
class TestBuildRawWhere:
def test_end_filter(self):
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_raw_where(None, end, None, None)
assert '"createdAt" <= $2::timestamptz' in sql
assert end in params
def test_model_filter(self):
sql, params = _build_raw_where(None, None, None, None, model="gpt-4")
assert '"model" = $' in sql
assert "gpt-4" in params
def test_block_name_filter(self):
sql, params = _build_raw_where(None, None, None, None, block_name="LLMBlock")
assert 'LOWER("blockName") = LOWER($' in sql
assert "LLMBlock" in params
def test_all_filters_combined(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_raw_where(
start, end, "anthropic", "u1", model="claude-3", block_name="LLM"
)
# trackingType (default), start, end, provider, user_id, model, block_name
assert len(params) == 7
assert "anthropic" in params
assert "u1" in params
assert "claude-3" in params
assert "LLM" in params
def test_default_tracking_type_is_cost_usd(self):
sql, params = _build_raw_where(None, None, None, None)
assert '"trackingType" = $1' in sql
assert params[0] == "cost_usd"
def test_explicit_tracking_type_overrides_default(self):
sql, params = _build_raw_where(None, None, None, None, tracking_type="tokens")
assert params[0] == "tokens"
def test_graph_exec_id_filter(self):
sql, params = _build_raw_where(None, None, None, None, graph_exec_id="exec-abc")
assert '"graphExecId" = $' in sql
assert "exec-abc" in params
def test_graph_exec_id_not_included_when_none(self):
sql, params = _build_raw_where(None, None, None, None)
assert "graphExecId" not in sql
def _make_entry(**overrides: object) -> PlatformCostEntry:
return PlatformCostEntry.model_validate(
{
@@ -286,8 +383,9 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups (no cost_usd rows for this user)
[{"userId": "u1"}], # distinct users
[provider_row], # total agg
[provider_row], # total agg (tracking_type=None → same as unfiltered)
]
)
mock_actions.find_many = AsyncMock(return_value=[mock_user])
@@ -301,6 +399,14 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[
[{"p50": 1000, "p75": 2000, "p95": 4000, "p99": 5000}],
[{"bucket": "$0-0.50", "count": 3}],
],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -313,6 +419,131 @@ class TestGetPlatformCostDashboard:
assert dashboard.by_provider[0].total_duration_seconds == 10.5
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].email == "a***@b.com"
assert dashboard.cost_p50_microdollars == 1000
assert dashboard.cost_p75_microdollars == 2000
assert dashboard.cost_p95_microdollars == 4000
assert dashboard.cost_p99_microdollars == 5000
assert len(dashboard.cost_buckets) == 1
# total_input/output_tokens come from total_agg_no_tracking_type_groups
# (provider_row has 1000/500)
assert dashboard.total_input_tokens == 1000
assert dashboard.total_output_tokens == 500
# Token averages must use token_bearing_requests (3) not cost_bearing (0)
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 3)
# No cost_usd rows in total_agg → avg_cost should be 0
assert dashboard.avg_cost_microdollars_per_request == 0.0
@pytest.mark.asyncio
async def test_cost_bearing_request_count_nonzero_when_filtering_by_tokens(self):
"""When filtering by tracking_type='tokens', cost_bearing_request_count
must still reflect cost_usd rows because by_user_tracking_groups is
queried without the tracking_type constraint."""
# total_agg only has a tokens row (because of the tracking_type filter)
total_row = _make_group_by_row(
provider="openai", tracking_type="tokens", cost=0, count=5
)
# by_user_tracking_groups returns BOTH rows (no tracking_type filter)
user_tracking_cost_usd_row = {
"_count": {"_all": 7},
"userId": "u1",
"trackingType": "cost_usd",
}
user_tracking_tokens_row = {
"_count": {"_all": 5},
"userId": "u1",
"trackingType": "tokens",
}
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[total_row], # by_provider
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
[
user_tracking_cost_usd_row,
user_tracking_tokens_row,
], # by_user_tracking
[{"userId": "u1"}], # distinct users
[total_row], # total agg (filtered)
[total_row], # total agg (no tracking_type filter)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
# by_user has 1 user with 5 total requests (tokens rows only due to filter)
# but per-user cost_bearing count should be 7 (from cost_usd rows in
# by_user_tracking_groups which uses where_no_tracking_type)
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].cost_bearing_request_count == 7
@pytest.mark.asyncio
async def test_global_avg_cost_nonzero_when_filtering_by_tokens(self):
"""When filtering by tracking_type='tokens', avg_cost_microdollars_per_request
must still reflect cost_usd rows from total_agg_no_tracking_type_groups,
not the filtered total_agg_groups which only has tokens rows."""
# filtered total_agg only has tokens rows (zero cost)
tokens_row = _make_group_by_row(
provider="openai", tracking_type="tokens", cost=0, count=5
)
# unfiltered total_agg has both rows (cost_usd carries the actual cost)
cost_usd_row = _make_group_by_row(
provider="openai", tracking_type="cost_usd", cost=10_000, count=4
)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[tokens_row], # by_provider
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
[], # by_user_tracking_groups
[{"userId": "u1"}], # distinct users
[tokens_row], # total agg (filtered — tokens only)
[tokens_row, cost_usd_row], # total agg (no tracking_type filter)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
# avg_cost_microdollars_per_request must be non-zero: cost_usd row
# (10_000 microdollars, 4 requests) is present in the unfiltered agg.
assert dashboard.avg_cost_microdollars_per_request == pytest.approx(10_000 / 4)
# avg token stats use token_bearing_requests from unfiltered agg (5)
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 5)
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 5)
@pytest.mark.asyncio
async def test_cache_tokens_aggregated_not_hardcoded(self):
@@ -335,8 +566,9 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups
[{"userId": "u2"}], # distinct users
[provider_row], # total agg
[provider_row], # total agg (tracking_type=None → same as unfiltered)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
@@ -350,6 +582,14 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[
[{"p50": 0, "p75": 0, "p95": 0, "p99": 0}],
[],
],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -361,7 +601,7 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
@@ -373,6 +613,11 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -381,13 +626,56 @@ class TestGetPlatformCostDashboard:
assert dashboard.total_users == 0
assert dashboard.by_provider == []
assert dashboard.by_user == []
assert dashboard.cost_p50_microdollars == 0
assert dashboard.cost_buckets == []
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
raw_mock = AsyncMock(side_effect=[[], []])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
raw_mock,
),
):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
# group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users,
# total agg filtered); the 6th call (total agg no-tracking-type) only runs
# when tracking_type is set.
assert mock_actions.group_by.await_count == 5
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})
# Raw SQL queries should receive provider and user_id as parameters
assert raw_mock.await_count == 2
raw_call_args = raw_mock.call_args_list[0][0] # positional args of 1st call
raw_params = raw_call_args[1:] # first arg is the query template
assert "openai" in raw_params
assert "u1" in raw_params
@pytest.mark.asyncio
async def test_user_tracking_groups_excludes_tracking_type_filter(self):
"""by_user_tracking_groups must NOT apply the tracking_type filter so that
cost_usd rows are always included even when the caller filters by 'tokens'."""
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
@@ -399,16 +687,54 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
await get_platform_cost_dashboard(tracking_type="tokens")
# group_by called 4 times (by_provider, by_user, distinct users, totals)
assert mock_actions.group_by.await_count == 4
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})
# Call index 2 is by_user_tracking_groups (0=by_provider, 1=by_user,
# 2=by_user_tracking, 3=distinct_users, 4=total_agg, 5=total_agg_no_tt).
tracking_call_where = mock_actions.group_by.call_args_list[2][1]["where"]
# The main filter applies trackingType; by_user_tracking must NOT.
assert "trackingType" not in tracking_call_where
# Other filters (e.g., date range, provider) are still passed through.
# The first call (by_provider) should have trackingType in its where dict.
provider_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
assert "trackingType" in provider_call_where
@pytest.mark.asyncio
async def test_graph_exec_id_filter_passed_to_queries(self):
"""graph_exec_id must be forwarded to both prisma and raw SQL queries."""
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
raw_mock = AsyncMock(side_effect=[[], []])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
raw_mock,
),
):
await get_platform_cost_dashboard(graph_exec_id="exec-xyz")
# Prisma groupBy where must include graphExecId
first_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
assert first_call_where.get("graphExecId") == "exec-xyz"
# Raw SQL params must include the exec id
raw_params = raw_mock.call_args_list[0][0][1:]
assert "exec-xyz" in raw_params
def _make_prisma_log_row(
@@ -509,6 +835,21 @@ class TestGetPlatformCostLogs:
# start provided — should appear in the where filter
assert "createdAt" in where
@pytest.mark.asyncio
async def test_graph_exec_id_filter(self):
mock_actions = MagicMock()
mock_actions.count = AsyncMock(return_value=0)
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, total = await get_platform_cost_logs(graph_exec_id="exec-abc")
where = mock_actions.count.call_args[1]["where"]
assert where.get("graphExecId") == "exec-abc"
class TestGetPlatformCostLogsForExport:
@pytest.mark.asyncio
@@ -594,6 +935,24 @@ class TestGetPlatformCostLogsForExport:
assert logs[0].cache_read_tokens == 50
assert logs[0].cache_creation_tokens == 25
@pytest.mark.asyncio
async def test_graph_exec_id_filter(self):
mock_actions = MagicMock()
mock_actions.find_many = AsyncMock(return_value=[])
with patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
):
logs, truncated = await get_platform_cost_logs_for_export(
graph_exec_id="exec-xyz"
)
where = mock_actions.find_many.call_args[1]["where"]
assert where.get("graphExecId") == "exec-xyz"
assert logs == []
assert truncated is False
@pytest.mark.asyncio
async def test_explicit_start_skips_default(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)

View File

@@ -0,0 +1,509 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any, cast
from backend.blocks import get_block
from backend.blocks._base import Block
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEntry,
NodeExecutionEntry,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
ZeroBalanceData,
)
from backend.notifications.notifications import queue_notification
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.logging import TruncatedLogger
from backend.util.metrics import DiscordChannel
from backend.util.settings import Settings
from .utils import LogMetadata, block_usage_cost, execution_usage_cost
if TYPE_CHECKING:
from backend.data.db_manager import DatabaseManagerClient
_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[Billing]")
settings = Settings()
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
# Hard cap on the multiplier passed to charge_extra_runtime_cost to
# protect against a corrupted llm_call_count draining a user's balance.
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
# 200 leaves headroom while preventing runaway charges.
_MAX_EXTRA_RUNTIME_COST = 200
def get_db_client() -> "DatabaseManagerClient":
return get_database_manager_client()
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
def resolve_block_cost(
node_exec: NodeExecutionEntry,
) -> tuple["Block | None", int, dict[str, Any]]:
"""Look up the block and compute its base usage cost for an exec.
Shared by charge_usage and charge_extra_runtime_cost so the
(get_block, block_usage_cost) lookup lives in exactly one place.
Returns ``(block, cost, matching_filter)``. ``block`` is ``None`` if
the block id can't be resolved — callers should treat that as
"nothing to charge".
"""
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return None, 0, {}
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.inputs)
return block, cost, matching_filter
def charge_usage(
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block, cost, matching_filter = resolve_block_cost(node_exec)
if not block:
return total_cost, 0
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
total_cost += cost
# execution_count=0 is used by charge_node_usage for nested tool calls
# which must not be pushed into higher execution-count tiers.
# execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
# so skip it entirely when execution_count is 0.
cost, usage_count = (
execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
total_cost += cost
return total_cost, remaining_balance
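To make the `execution_count=0` guard above concrete: with a tier rule of the form "charge every N executions" implemented as a modulo check, zero satisfies the check and would bill a nested tool call that consumed no tier. A tiny sketch (the threshold value is hypothetical):

```python
# Why execution_count == 0 must bypass execution_usage_cost entirely:
THRESHOLD = 100  # hypothetical "charge every N executions" tier size

def tier_boundary(execution_count: int) -> bool:
    return execution_count % THRESHOLD == 0

assert tier_boundary(100)  # intended: the 100th execution starts a new tier
assert tier_boundary(0)    # pitfall: 0 % 100 == 0 would also trigger a charge
```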
def _charge_extra_runtime_cost_sync(
node_exec: NodeExecutionEntry,
capped_count: int,
) -> tuple[int, int]:
"""Synchronous implementation — runs in a thread-pool worker.
Called only from charge_extra_runtime_cost. Do not call directly from
async code.
Note: ``resolve_block_cost`` is called again here (rather than reusing
the result from ``charge_usage`` at the start of execution) because the
two calls happen in separate thread-pool workers and sharing mutable
state across workers would require locks. The block config is immutable
during a run, so the repeated lookup is safe and produces the same cost;
the only overhead is an extra registry lookup.
"""
db_client = get_db_client()
block, cost, matching_filter = resolve_block_cost(node_exec)
if not block or cost <= 0:
return 0, 0
total_extra_cost = cost * capped_count
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=total_extra_cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input={
**matching_filter,
"extra_runtime_cost_count": capped_count,
},
reason=(
f"Extra agent-mode iterations for {block.name} "
f"({capped_count} additional LLM calls)"
),
),
)
return total_extra_cost, remaining_balance
async def charge_extra_runtime_cost(
node_exec: NodeExecutionEntry,
extra_count: int,
) -> tuple[int, int]:
"""Charge a block extra runtime cost beyond the initial run.
Used by agent-mode blocks (e.g. OrchestratorBlock) that make multiple
LLM calls within a single node execution. The first iteration is already
charged by charge_usage; this method charges *extra_count* additional
copies of the block's base cost.
Returns ``(total_extra_cost, remaining_balance)``. May raise
``InsufficientBalanceError`` if the user can't afford the charge.
"""
if extra_count <= 0:
return 0, 0
# Cap to protect against a corrupted llm_call_count.
capped = min(extra_count, _MAX_EXTRA_RUNTIME_COST)
if extra_count > _MAX_EXTRA_RUNTIME_COST:
logger.warning(
f"extra_count {extra_count} exceeds cap {_MAX_EXTRA_RUNTIME_COST};"
f" charging {_MAX_EXTRA_RUNTIME_COST} (llm_call_count may be corrupted)"
)
return await asyncio.to_thread(_charge_extra_runtime_cost_sync, node_exec, capped)
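The clamping rule above in isolation — a runnable sketch, not the production path:

```python
_MAX_EXTRA_RUNTIME_COST = 200  # mirrors the module-level cap

def capped_iterations(extra_count: int) -> int:
    if extra_count <= 0:
        return 0  # nothing to charge
    return min(extra_count, _MAX_EXTRA_RUNTIME_COST)

assert capped_iterations(-5) == 0
assert capped_iterations(50) == 50    # typical agent-mode run, well under the cap
assert capped_iterations(500) == 200  # corrupted llm_call_count is clamped, not billed
```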
async def charge_node_usage(node_exec: NodeExecutionEntry) -> tuple[int, int]:
"""Charge a single node execution to the user.
Public async wrapper around charge_usage for blocks (e.g. the
OrchestratorBlock) that spawn nested node executions outside the main
queue and therefore need to charge them explicitly.
Also handles low-balance notification so callers don't need to touch
private functions directly.
Note: this **does not** increment the global execution counter
(``increment_execution_count``). Nested tool executions are sub-steps
of a single block run from the user's perspective and should not push
the user into higher per-execution cost tiers.
"""
def _run():
total_cost, remaining = charge_usage(node_exec, 0)
if total_cost > 0:
handle_low_balance(
get_db_client(), node_exec.user_id, remaining, total_cost
)
return total_cost, remaining
return await asyncio.to_thread(_run)
async def try_send_insufficient_funds_notif(
user_id: str,
graph_id: str,
error: InsufficientBalanceError,
log_metadata: LogMetadata,
) -> None:
"""Send an insufficient-funds notification, swallowing failures."""
try:
await asyncio.to_thread(
handle_insufficient_funds_notif,
get_db_client(),
user_id,
graph_id,
error,
)
except Exception as notif_error: # pragma: no cover
log_metadata.warning(
f"Failed to send insufficient funds notification: {notif_error}"
)
async def handle_post_execution_billing(
node: Node,
node_exec: NodeExecutionEntry,
execution_stats: NodeExecutionStats,
status: ExecutionStatus,
log_metadata: LogMetadata,
) -> None:
"""Charge extra runtime cost for blocks that opt into per-LLM-call billing.
The first LLM call is already covered by charge_usage(); each additional
call costs another base_cost. Skipped for dry runs and failed runs.
InsufficientBalanceError here is a post-hoc billing leak: the work is
already done but the user can no longer pay. The run stays COMPLETED and
the error is logged with ``billing_leak: True`` for alerting.
"""
extra_iterations = (
cast(Block, node.block).extra_runtime_cost(execution_stats)
if status == ExecutionStatus.COMPLETED
and not node_exec.execution_context.dry_run
else 0
)
if extra_iterations <= 0:
return
try:
extra_cost, remaining_balance = await charge_extra_runtime_cost(
node_exec,
extra_iterations,
)
if extra_cost > 0:
execution_stats.extra_cost += extra_cost
await asyncio.to_thread(
handle_low_balance,
get_db_client(),
node_exec.user_id,
remaining_balance,
extra_cost,
)
except InsufficientBalanceError as e:
log_metadata.error(
"billing_leak: insufficient balance after "
f"{node.block.name} completed {extra_iterations} "
f"extra iterations",
extra={
"billing_leak": True,
"user_id": node_exec.user_id,
"graph_id": node_exec.graph_id,
"block_id": node_exec.block_id,
"extra_runtime_cost_count": extra_iterations,
"error": str(e),
},
)
# Do NOT set execution_stats.error — the node ran to completion,
# only the post-hoc charge failed. See class-level billing-leak
# contract documentation.
await try_send_insufficient_funds_notif(
node_exec.user_id,
node_exec.graph_id,
e,
log_metadata,
)
except Exception as e:
log_metadata.error(
f"billing_leak: failed to charge extra iterations for {node.block.name}",
extra={
"billing_leak": True,
"user_id": node_exec.user_id,
"graph_id": node_exec.graph_id,
"block_id": node_exec.block_id,
"extra_runtime_cost_count": extra_iterations,
"error_type": type(e).__name__,
"error": str(e),
},
exc_info=True,
)
def handle_agent_run_notif(
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
) -> None:
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
def handle_insufficient_funds_notif(
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
e: InsufficientBalanceError,
) -> None:
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(f"Failed to send insufficient funds Discord alert: {alert_error}")
def handle_low_balance(
db_client: "DatabaseManagerClient",
user_id: str,
current_balance: int,
transaction_cost: int,
) -> None:
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.warning(f"Failed to send low balance Discord alert: {e}")

View File

@@ -21,11 +21,9 @@ from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
ExecutionContext,
@@ -39,27 +37,18 @@ from backend.data.execution import (
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
drain_pending_cost_logs,
log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
get_notification_manager_client,
)
from backend.util.decorator import (
async_error_logged,
@@ -75,7 +64,6 @@ from backend.util.exceptions import (
)
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
continuous_retry,
@@ -84,6 +72,7 @@ from backend.util.retry import (
)
from backend.util.settings import Settings
from . import billing
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
@@ -98,9 +87,7 @@ from .utils import (
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
@@ -126,40 +113,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -681,12 +634,16 @@ class ExecutionProcessor:
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
await billing.handle_post_execution_billing(
node, node_exec, execution_stats, status, log_metadata
)
graph_stats, graph_stats_lock = graph_stats_pair
with graph_stats_lock:
graph_stats.node_count += 1 + execution_stats.extra_steps
graph_stats.nodes_cputime += execution_stats.cputime
graph_stats.nodes_walltime += execution_stats.walltime
graph_stats.cost += execution_stats.extra_cost
graph_stats.cost += execution_stats.cost + execution_stats.extra_cost
if isinstance(execution_stats.error, Exception):
graph_stats.node_error_count += 1
@@ -716,6 +673,18 @@ class ExecutionProcessor:
db_client=db_client,
)
# If the node failed because a nested tool charge raised IBE,
# send the user notification so they understand why the run stopped.
if status == ExecutionStatus.FAILED and isinstance(
execution_stats.error, InsufficientBalanceError
):
await billing.try_send_insufficient_funds_notif(
node_exec.user_id,
node_exec.graph_id,
execution_stats.error,
log_metadata,
)
return execution_stats
@async_time_measured
@@ -935,7 +904,7 @@ class ExecutionProcessor:
)
finally:
# Communication handling
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
billing.handle_agent_run_notif(db_client, graph_exec, exec_stats)
update_graph_execution_state(
db_client=db_client,
@@ -944,57 +913,18 @@ class ExecutionProcessor:
stats=exec_stats,
)
def _charge_usage(
async def charge_node_usage(
self,
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return total_cost, 0
return await billing.charge_node_usage(node_exec)
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
total_cost += cost
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
total_cost += cost
return total_cost, remaining_balance
async def charge_extra_runtime_cost(
self,
node_exec: NodeExecutionEntry,
extra_count: int,
) -> tuple[int, int]:
return await billing.charge_extra_runtime_cost(node_exec, extra_count)
@time_measured
def _on_graph_execution(
@@ -1106,7 +1036,7 @@ class ExecutionProcessor:
# Charge usage (may raise) — skipped for dry runs
try:
if not graph_exec.execution_context.dry_run:
cost, remaining_balance = self._charge_usage(
cost, remaining_balance = billing.charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(
graph_exec.user_id
@@ -1115,7 +1045,7 @@ class ExecutionProcessor:
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
billing.handle_low_balance(
db_client=db_client,
user_id=graph_exec.user_id,
current_balance=remaining_balance,
@@ -1135,7 +1065,7 @@ class ExecutionProcessor:
status=ExecutionStatus.FAILED,
)
self._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client,
graph_exec.user_id,
graph_exec.graph_id,
@@ -1397,165 +1327,6 @@ class ExecutionProcessor:
):
execution_queue.add(next_execution)
def _handle_agent_run_notif(
self,
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
def _handle_insufficient_funds_notif(
self,
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(
f"Failed to send insufficient funds Discord alert: {alert_error}"
)
def _handle_low_balance(
self,
db_client: "DatabaseManagerClient",
user_id: str,
current_balance: int,
transaction_cost: int,
):
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.warning(f"Failed to send low balance Discord alert: {e}")
class ExecutionManager(AppProcess):
def __init__(self):

View File

@@ -4,9 +4,9 @@ import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
from backend.executor import billing
from backend.executor.billing import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
@@ -25,7 +25,6 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -36,13 +35,13 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Setup mocks
@@ -63,7 +62,7 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -99,7 +98,6 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -110,13 +108,13 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Setup mocks
@@ -134,7 +132,7 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -154,7 +152,6 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
@@ -166,12 +163,12 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
with patch("backend.executor.billing.queue_notification"), patch(
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
mock_client = MagicMock()
@@ -190,7 +187,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
@@ -198,7 +195,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
@@ -227,7 +224,7 @@ async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
@@ -263,7 +260,7 @@ async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestSe
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
@@ -290,7 +287,7 @@ async def test_clear_insufficient_funds_notifications_handles_redis_error(
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
@@ -310,7 +307,6 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -321,13 +317,13 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
mock_client = MagicMock()
@@ -346,7 +342,7 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -370,7 +366,7 @@ async def test_add_transaction_clears_notifications_on_grant(server: SpinTestSer
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -412,7 +408,7 @@ async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestSe
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -450,7 +446,7 @@ async def test_add_transaction_skips_clearing_for_inactive_transaction(
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -486,7 +482,7 @@ async def test_add_transaction_skips_clearing_for_usage_transaction(
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -521,7 +517,7 @@ async def test_enable_transaction_clears_notifications(server: SpinTestServer):
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
) as mock_query, patch("backend.executor.billing.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()
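The mechanical rename running through this test file (and the low-balance tests below) follows the standard `unittest.mock` rule: patch the name where it is looked up, not where it is defined. Since `billing.py` imports `queue_notification` into its own namespace, that binding is what the tests must replace:

```python
from unittest.mock import patch

# billing.py does: from backend.notifications.notifications import queue_notification
# Patching the defining module would leave billing's local binding untouched.
with patch("backend.executor.billing.queue_notification") as mock_queue:
    ...  # code paths inside billing.py now hit mock_queue
```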

View File

@@ -4,26 +4,25 @@ import pytest
from prisma.enums import NotificationType
from backend.data.notifications import LowBalanceData
from backend.executor.manager import ExecutionProcessor
from backend.executor import billing
from backend.util.test import SpinTestServer
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
"""Test that _handle_low_balance triggers notification when crossing threshold."""
"""Test that handle_low_balance triggers notification when crossing threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 400 # $4 - below $5 threshold
transaction_cost = 600 # $6 transaction
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -37,7 +36,7 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,
@@ -69,7 +68,6 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
):
"""Test that no notification is sent when not crossing the threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 600 # $6 - above $5 threshold
transaction_cost = (
@@ -78,11 +76,11 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -94,7 +92,7 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
mock_db_client = MagicMock()
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,
@@ -112,7 +110,6 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
):
"""Test that no notification is sent when already below threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 300 # $3 - below $5 threshold
transaction_cost = (
@@ -121,11 +118,11 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -137,7 +134,7 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
mock_db_client = MagicMock()
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,

View File

@@ -18,9 +18,13 @@ images: {
"""
import asyncio
import json
import random
from pathlib import Path
from typing import Any, Dict, List
import prisma.enums as prisma_enums
import prisma.models as prisma_models
from faker import Faker
# Import API functions from the backend
@@ -30,10 +34,12 @@ from backend.api.features.store.db import (
create_store_submission,
review_store_submission,
)
from backend.api.features.store.model import StoreSubmission
from backend.blocks.io import AgentInputBlock
from backend.data.auth.api_key import create_api_key
from backend.data.credit import get_user_credit_model
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.graph import Graph, Link, Node, create_graph, make_graph_model
from backend.data.user import get_or_create_user
from backend.util.clients import get_supabase
@@ -60,6 +66,31 @@ MAX_REVIEWS_PER_VERSION = 5
GUARANTEED_FEATURED_AGENTS = 8
GUARANTEED_FEATURED_CREATORS = 5
GUARANTEED_TOP_AGENTS = 10
E2E_MARKETPLACE_CREATOR_EMAIL = "test123@example.com"
E2E_MARKETPLACE_CREATOR_USERNAME = "e2e-marketplace"
E2E_MARKETPLACE_AGENT_SLUG = "e2e-calculator-agent"
E2E_MARKETPLACE_AGENT_NAME = "E2E Calculator Agent"
E2E_MARKETPLACE_AGENT_INPUT_VALUE = 8
E2E_MARKETPLACE_AGENT_OUTPUT_VALUE = 42
_LOCAL_TEMPLATE_PATH = (
Path(__file__).resolve().parents[1] / "agents" / "calculator-agent.json"
)
_DOCKER_TEMPLATE_PATH = Path(
"/app/autogpt_platform/backend/agents/calculator-agent.json"
)
E2E_MARKETPLACE_AGENT_TEMPLATE_PATH = (
_LOCAL_TEMPLATE_PATH if _LOCAL_TEMPLATE_PATH.exists() else _DOCKER_TEMPLATE_PATH
)
SEEDED_TEST_EMAILS = [
"test123@example.com",
"e2e.qa.auth@example.com",
"e2e.qa.builder@example.com",
"e2e.qa.library@example.com",
"e2e.qa.marketplace@example.com",
"e2e.qa.settings@example.com",
"e2e.qa.parallel.a@example.com",
"e2e.qa.parallel.b@example.com",
]
def get_image():
@@ -100,6 +131,25 @@ def get_category():
return random.choice(categories)
def load_deterministic_marketplace_graph() -> Graph:
graph = Graph.model_validate(
json.loads(E2E_MARKETPLACE_AGENT_TEMPLATE_PATH.read_text())
)
graph.name = E2E_MARKETPLACE_AGENT_NAME
graph.description = (
"Deterministic marketplace calculator graph for Playwright PR E2E coverage."
)
for node in graph.nodes:
if (
node.block_id == AgentInputBlock().id
and node.input_default.get("value") is None
):
node.input_default["value"] = E2E_MARKETPLACE_AGENT_INPUT_VALUE
return graph
class TestDataCreator:
"""Creates test data using API functions for E2E tests."""
@@ -123,9 +173,9 @@ class TestDataCreator:
for i in range(NUM_USERS):
try:
# Generate test user data
if i == 0:
# First user should have test123@gmail.com email for testing
email = "test123@gmail.com"
if i < len(SEEDED_TEST_EMAILS):
# Keep a deterministic pool for Playwright global setup and PR smoke flows
email = SEEDED_TEST_EMAILS[i]
else:
email = faker.unique.email()
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
@@ -547,6 +597,46 @@ class TestDataCreator:
print(f"Error updating profile {profile.id}: {e}")
continue
deterministic_creator = next(
(
user
for user in self.users
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
),
None,
)
if deterministic_creator:
deterministic_profile = next(
(
profile
for profile in existing_profiles
if profile.userId == deterministic_creator["id"]
),
None,
)
if deterministic_profile:
try:
updated_profile = await prisma.profile.update(
where={"id": deterministic_profile.id},
data={
"name": "E2E Marketplace Creator",
"username": E2E_MARKETPLACE_CREATOR_USERNAME,
"description": "Deterministic marketplace creator for Playwright PR E2E coverage.",
"links": ["https://example.com/e2e-marketplace"],
"avatarUrl": get_image(),
"isFeatured": True,
},
)
profiles = [
profile
for profile in profiles
if profile.get("id") != deterministic_profile.id
]
if updated_profile is not None:
profiles.append(updated_profile.model_dump())
except Exception as e:
print(f"Error updating deterministic E2E creator profile: {e}")
self.profiles = profiles
return profiles
@@ -562,58 +652,184 @@ class TestDataCreator:
featured_count = 0
submission_counter = 0
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
# Create a deterministic calculator marketplace agent for PR E2E coverage
test_user = next(
(user for user in self.users if user["email"] == "test123@gmail.com"), None
(
user
for user in self.users
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
),
None,
)
if test_user and self.agent_graphs:
test_submission_data = {
"user_id": test_user["id"],
"graph_id": self.agent_graphs[0]["id"],
"graph_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
"sub_heading": "A test agent for frontend testing",
"video_url": "https://www.youtube.com/watch?v=test123",
"image_urls": [
"https://picsum.photos/200/300",
"https://picsum.photos/200/301",
"https://picsum.photos/200/302",
],
"description": "This is a test agent submission specifically created for frontend testing purposes.",
"categories": ["test", "demo", "frontend"],
"changes_summary": "Initial test submission",
}
if test_user:
deterministic_graph = None
try:
test_submission = await create_store_submission(**test_submission_data)
submissions.append(test_submission.model_dump())
print("✅ Created special test store submission for test123@gmail.com")
# ALWAYS approve and feature the test submission
if test_submission.listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
reviewer_id=test_user["id"],
existing_graph = await prisma_models.AgentGraph.prisma().find_first(
where={
"userId": test_user["id"],
"name": E2E_MARKETPLACE_AGENT_NAME,
"isActive": True,
},
order={"version": "desc"},
)
if existing_graph:
deterministic_graph = {
"id": existing_graph.id,
"version": existing_graph.version,
"name": existing_graph.name,
"userId": test_user["id"],
}
self.agent_graphs.append(deterministic_graph)
print(
"✅ Reused existing deterministic marketplace graph: "
f"{existing_graph.id}"
)
approved_submissions.append(approved_submission.model_dump())
print("✅ Approved test store submission")
await prisma.storelistingversion.update(
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
else:
deterministic_graph_model = make_graph_model(
load_deterministic_marketplace_graph(),
test_user["id"],
)
featured_count += 1
print("🌟 Marked test agent as FEATURED")
deterministic_graph_model.reassign_ids(
user_id=test_user["id"],
reassign_graph_id=True,
)
created_deterministic_graph = await create_graph(
deterministic_graph_model,
test_user["id"],
)
deterministic_graph = created_deterministic_graph.model_dump()
deterministic_graph["userId"] = test_user["id"]
self.agent_graphs.append(deterministic_graph)
print("✅ Created deterministic marketplace graph")
except Exception as e:
print(f"Error creating test store submission: {e}")
import traceback
print(f"Error creating deterministic marketplace graph: {e}")
traceback.print_exc()
if deterministic_graph is None and self.agent_graphs:
test_user_graphs = [
graph
for graph in self.agent_graphs
if graph.get("userId") == test_user["id"]
]
deterministic_graph = next(
(
graph
for graph in test_user_graphs
if not graph.get("name", "").startswith("DummyInput ")
),
test_user_graphs[0] if test_user_graphs else None,
)
if deterministic_graph:
test_submission_data = {
"user_id": test_user["id"],
"graph_id": deterministic_graph["id"],
"graph_version": deterministic_graph.get("version", 1),
"slug": E2E_MARKETPLACE_AGENT_SLUG,
"name": E2E_MARKETPLACE_AGENT_NAME,
"sub_heading": "A deterministic calculator agent for PR E2E coverage",
"video_url": "https://www.youtube.com/watch?v=test123",
"image_urls": [
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
],
"description": (
"A deterministic marketplace calculator agent that adds "
f"{E2E_MARKETPLACE_AGENT_INPUT_VALUE} and 34 to produce "
f"{E2E_MARKETPLACE_AGENT_OUTPUT_VALUE} for frontend E2E coverage."
),
"categories": ["test", "demo", "frontend"],
"changes_summary": (
"Initial deterministic calculator submission seeded from "
"backend/agents/calculator-agent.json"
),
}
try:
existing_deterministic_submission = (
await prisma_models.StoreListingVersion.prisma().find_first(
where={
"isDeleted": False,
"StoreListing": {
"is": {
"owningUserId": test_user["id"],
"slug": E2E_MARKETPLACE_AGENT_SLUG,
"isDeleted": False,
}
},
},
include={"StoreListing": True},
order={"version": "desc"},
)
)
if existing_deterministic_submission:
test_submission = StoreSubmission.from_listing_version(
existing_deterministic_submission
)
submissions.append(test_submission.model_dump())
print(
"✅ Reused deterministic marketplace submission: "
f"{E2E_MARKETPLACE_AGENT_NAME}"
)
else:
test_submission = await create_store_submission(
**test_submission_data
)
submissions.append(test_submission.model_dump())
print(
"✅ Created deterministic marketplace submission: "
f"{E2E_MARKETPLACE_AGENT_NAME}"
)
current_status = (
existing_deterministic_submission.submissionStatus
if existing_deterministic_submission
else test_submission.status
)
is_featured = bool(
existing_deterministic_submission
and existing_deterministic_submission.isFeatured
)
if test_submission.listing_version_id:
if current_status != prisma_enums.SubmissionStatus.APPROVED:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Deterministic calculator submission approved",
internal_comments="Auto-approved PR E2E marketplace submission",
reviewer_id=test_user["id"],
)
approved_submissions.append(
approved_submission.model_dump()
)
print("✅ Approved deterministic marketplace submission")
else:
approved_submissions.append(test_submission.model_dump())
print(
"✅ Deterministic marketplace submission already approved"
)
if is_featured:
featured_count += 1
print("🌟 Deterministic marketplace agent already FEATURED")
else:
await prisma.storelistingversion.update(
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
print(
"🌟 Marked deterministic marketplace agent as FEATURED"
)
except Exception as e:
print(f"Error creating deterministic marketplace submission: {e}")
import traceback
traceback.print_exc()
# Create regular submissions for all users
for user in self.users:

View File

@@ -6,7 +6,8 @@
# 5. CLI arguments - docker compose run -e VAR=value
# Common backend environment - Docker service names
x-backend-env: &backend-env # Docker internal service hostnames (override localhost defaults)
x-backend-env:
&backend-env # Docker internal service hostnames (override localhost defaults)
PYRO_HOST: "0.0.0.0"
AGENTSERVER_HOST: rest_server
SCHEDULER_HOST: scheduler_server
@@ -39,7 +40,12 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: migrate
command: ["sh", "-c", "prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy"]
command:
[
"sh",
"-c",
"prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy",
]
develop:
watch:
- path: ./
@@ -79,8 +85,8 @@ services:
falkordb:
image: falkordb/falkordb:latest
ports:
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
environment:
- REDIS_ARGS=--requirepass ${GRAPHITI_FALKORDB_PASSWORD:-}
volumes:
@@ -88,7 +94,11 @@ services:
networks:
- app-network
healthcheck:
test: ["CMD-SHELL", "redis-cli -p 6379 -a \"${GRAPHITI_FALKORDB_PASSWORD:-}\" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1"]
test:
[
"CMD-SHELL",
'redis-cli -p 6379 -a "${GRAPHITI_FALKORDB_PASSWORD:-}" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1',
]
interval: 10s
timeout: 5s
retries: 5
@@ -300,19 +310,6 @@ services:
condition: service_completed_successfully
database_manager:
condition: service_started
# healthcheck:
# test:
# [
# "CMD",
# "curl",
# "-f",
# "-X",
# "POST",
# "http://localhost:8003/health_check",
# ]
# interval: 10s
# timeout: 10s
# retries: 5
<<: *backend-env-files
environment:
<<: *backend-env

View File

@@ -193,3 +193,4 @@ services:
- copilot_executor
- websocket_server
- database_manager
- scheduler_server

View File

@@ -8,6 +8,7 @@ const config: StorybookConfig = {
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/renderers/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/app/[(]platform[)]/copilot/**/*.stories.@(js|jsx|mjs|ts|tsx)",
],
addons: [
"@storybook/addon-a11y",

View File

@@ -81,8 +81,10 @@ Every time a new Front-end dependency is added by you or others, you will need t
- `pnpm lint` - Run ESLint and Prettier checks
- `pnpm format` - Format code with Prettier
- `pnpm types` - Run TypeScript type checking
- `pnpm test` - Run Playwright tests
- `pnpm test-ui` - Run Playwright tests with UI
- `pnpm test:unit` - Run the Vitest integration and unit suite with coverage
- `pnpm test` - Run the Playwright E2E suite used in CI
- `pnpm test-ui` - Run the same Playwright E2E suite with UI
- `pnpm test:e2e:no-build` - Run the same Playwright E2E suite against a running app
- `pnpm fetch:openapi` - Fetch OpenAPI spec from backend
- `pnpm generate:api-client` - Generate API client from OpenAPI spec
- `pnpm generate:api` - Fetch OpenAPI spec and generate API client

View File

@@ -121,35 +121,49 @@ Only when the component has complex internal logic that is hard to exercise thro
### Running
```bash
pnpm test # build + run all Playwright tests
pnpm test-ui # run with Playwright UI
pnpm test:no-build # run against a running dev server
pnpm test # build + run the Playwright E2E suite used in CI
pnpm test-ui # run the same E2E suite with Playwright UI
pnpm test:e2e:no-build # run the same E2E suite against a running dev server
pnpm exec playwright test # run the same eight-spec Playwright suite directly
```
### Setup
1. Start the backend + Supabase stack:
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
2. Seed rich E2E data (creates `test123@example.com` with library agents):
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
### How Playwright setup works
- Playwright runs from `frontend/playwright.config.ts` with a global setup step
- Global setup creates a user pool via the real signup UI, stored in `frontend/.auth/user-pool.json`
- `getTestUser()` (from `src/tests/utils/auth.ts`) pulls a random user from the pool
- Playwright runs from `frontend/playwright.config.ts` and keeps browser-only code in `frontend/src/playwright/`
- Global setup creates reusable auth states for deterministic seeded accounts in `frontend/.auth/states/`
- `getTestUser()` (from `src/playwright/utils/auth.ts`) picks one seeded account for general auth coverage (usage sketched below)
- `getTestUserWithLibraryAgents()` uses the rich user created by the data script
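A minimal sketch of a spec that consumes these helpers (illustrative only; `getTestUser()` is assumed to be async here, and the route is an example, not one of the real specs):
```ts
import { test, expect } from "@playwright/test";
import { getTestUser } from "./utils/auth";

test("seeded account can open the library", async ({ page }) => {
  const user = await getTestUser();
  expect(user).toBeDefined(); // deterministic seeded account; auth state pre-built in .auth/states/
  await page.goto("/library");
  await expect(page).toHaveURL(/library/);
});
```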
### Test users
- **User pool (basic users)** — created automatically by Playwright global setup. Used by `getTestUser()`
- **Seeded E2E accounts** — created by backend fixtures and logged in during Playwright global setup. Used by `getTestUser()` and `E2E_AUTH_STATES`
- **Rich user with library agents** — created by `backend/test/e2e_test_data.py`. Used by `getTestUserWithLibraryAgents()`
### Current Playwright E2E suite
The CI suite is intentionally limited to the cross-page journeys that still require a real browser. Playwright discovers the PR-gating specs via the `*-happy-path.spec.ts` naming pattern inside `src/playwright/`:
- `src/playwright/auth-happy-path.spec.ts`
- `src/playwright/settings-happy-path.spec.ts`
- `src/playwright/api-keys-happy-path.spec.ts`
- `src/playwright/builder-happy-path.spec.ts`
- `src/playwright/library-happy-path.spec.ts`
- `src/playwright/marketplace-happy-path.spec.ts`
- `src/playwright/publish-happy-path.spec.ts`
- `src/playwright/copilot-happy-path.spec.ts`
### Resetting the DB
If you reset the Docker DB and logins start failing (one-shot sequence sketched after these steps):
1. Delete `frontend/.auth/user-pool.json`
1. Delete `frontend/.auth/states/*` and `frontend/.auth/user-pool.json` if it exists
2. Re-run `poetry run python test/e2e_test_data.py`
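The same reset as a one-shot sketch (assumes you start from `autogpt_platform`, per the paths above):
```bash
rm -rf frontend/.auth/states frontend/.auth/user-pool.json
(cd backend && poetry run python test/e2e_test_data.py)
```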
## Storybook

View File

@@ -13,11 +13,13 @@
"lint": "next lint && prettier --check .",
"format": "next lint --fix; prettier --write .",
"types": "tsc --noEmit",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:ui",
"test:unit": "vitest run --coverage",
"test:unit:watch": "vitest",
"test:no-build": "playwright test",
"test:e2e": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test:e2e:no-build": "playwright test",
"test:e2e:ui": "playwright test --ui",
"gentests": "playwright codegen http://localhost:3000",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

View File

@@ -7,10 +7,22 @@ import { defineConfig, devices } from "@playwright/test";
import dotenv from "dotenv";
import fs from "fs";
import path from "path";
import { buildCookieConsentStorageState } from "./src/playwright/credentials/storage-state";
dotenv.config({ path: path.resolve(__dirname, ".env") });
dotenv.config({ path: path.resolve(__dirname, "../backend/.env") });
const frontendRoot = __dirname.replaceAll("\\", "/");
const configuredBaseURL =
process.env.PLAYWRIGHT_BASE_URL ?? "http://localhost:3000";
const parsedBaseURL = new URL(configuredBaseURL);
const baseURL = parsedBaseURL.toString().replace(/\/$/, "");
const baseOrigin = parsedBaseURL.origin;
const jsonReporterOutputFile = process.env.PLAYWRIGHT_JSON_OUTPUT_FILE;
const configuredWorkers = process.env.PLAYWRIGHT_WORKERS
? Number(process.env.PLAYWRIGHT_WORKERS)
: process.env.CI
? 8
: undefined;
// Directory where CI copies .next/static from the Docker container
const staticCoverageDir = path.resolve(__dirname, ".next-static-coverage");
@@ -57,17 +69,18 @@ function resolveSourceMap(sourcePath: string) {
}
export default defineConfig({
testDir: "./src/tests",
testDir: "./src/playwright",
testMatch: /.*-happy-path\.spec\.ts/,
/* Global setup file that runs before all tests */
globalSetup: "./src/tests/global-setup.ts",
globalSetup: "./src/playwright/global-setup.ts",
/* Run tests in files in parallel */
fullyParallel: true,
/* Fail the build on CI if you accidentally left test.only in the source code. */
forbidOnly: !!process.env.CI,
/* Retry on CI only */
retries: process.env.CI ? 1 : 0,
/* use more workers on CI. */
workers: process.env.CI ? 4 : undefined,
retries: process.env.CI ? Number(process.env.PLAYWRIGHT_RETRIES ?? 2) : 0,
/* Higher worker count keeps PR smoke runtime down without sharing page state. */
workers: configuredWorkers,
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
reporter: [
["list"],
@@ -92,40 +105,25 @@ export default defineConfig({
},
},
],
...(jsonReporterOutputFile
? [["json", { outputFile: jsonReporterOutputFile }] as const]
: []),
],
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
use: {
/* Base URL to use in actions like `await page.goto('/')`. */
baseURL: "http://localhost:3000/",
baseURL,
/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
screenshot: "only-on-failure",
bypassCSP: true,
/* Helps debugging failures */
trace: "retain-on-failure",
video: "retain-on-failure",
trace: process.env.CI ? "on-first-retry" : "retain-on-failure",
video: process.env.CI ? "off" : "retain-on-failure",
/* Auto-accept cookies in all tests to prevent banner interference */
storageState: {
cookies: [],
origins: [
{
origin: "http://localhost:3000",
localStorage: [
{
name: "autogpt_cookie_consent",
value: JSON.stringify({
hasConsented: true,
timestamp: Date.now(),
analytics: true,
monitoring: true,
}),
},
],
},
],
},
storageState: buildCookieConsentStorageState(baseOrigin),
},
/* Maximum time one test can run for */
timeout: 25000,
@@ -133,7 +131,7 @@ export default defineConfig({
/* Configure web server to start automatically (local dev only) */
webServer: {
command: "pnpm start",
url: "http://localhost:3000",
url: baseURL,
reuseExistingServer: true,
},

View File

@@ -3,6 +3,7 @@ import {
screen,
cleanup,
waitFor,
fireEvent,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
@@ -29,6 +30,16 @@ const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
total_input_tokens: 0,
total_output_tokens: 0,
avg_input_tokens_per_request: 0,
avg_output_tokens_per_request: 0,
avg_cost_microdollars_per_request: 0,
cost_p50_microdollars: 0,
cost_p75_microdollars: 0,
cost_p95_microdollars: 0,
cost_p99_microdollars: 0,
cost_buckets: [],
by_provider: [],
by_user: [],
};
@@ -47,6 +58,20 @@ const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
total_input_tokens: 150000,
total_output_tokens: 60000,
avg_input_tokens_per_request: 2500,
avg_output_tokens_per_request: 1000,
avg_cost_microdollars_per_request: 83333,
cost_p50_microdollars: 50000,
cost_p75_microdollars: 100000,
cost_p95_microdollars: 250000,
cost_p99_microdollars: 500000,
cost_buckets: [
{ bucket: "$0-0.50", count: 80 },
{ bucket: "$0.50-1", count: 15 },
{ bucket: "$1-2", count: 5 },
],
by_provider: [
{
provider: "openai",
@@ -75,6 +100,7 @@ const dashboardWithData: PlatformCostDashboard = {
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
cost_bearing_request_count: 40,
},
],
};
@@ -134,9 +160,14 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
// Known Cost and Estimated Total cards render $0.0000
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
// Typical/Upper/High/Peak Cost) show $0.0000
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(2);
expect(zeroCostItems.length).toBe(7);
expect(screen.getByText("No cost data yet")).toBeDefined();
});
@@ -155,7 +186,9 @@ describe("PlatformCostContent", () => {
);
expect(screen.getByText("$5.0000")).toBeDefined();
expect(screen.getByText("100")).toBeDefined();
expect(screen.getByText("5")).toBeDefined();
// "5" appears in multiple places (Active Users card + bucket count),
// so verify at least one element renders it.
expect(screen.getAllByText("5").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("openai")).toBeDefined();
expect(screen.getByText("google_maps")).toBeDefined();
});
@@ -223,10 +256,83 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Original 4 cards
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
// New average/token cards
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
expect(screen.getByText("Total Tokens")).toBeDefined();
// Percentile cards (friendlier labels)
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
expect(screen.getByText("High Cost (P95)")).toBeDefined();
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
});
it("renders cost distribution buckets", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
expect(screen.getByText("$0-0.50")).toBeDefined();
expect(screen.getByText("$0.50-1")).toBeDefined();
expect(screen.getByText("$1-2")).toBeDefined();
expect(screen.getByText("80")).toBeDefined();
expect(screen.getByText("15")).toBeDefined();
});
it("renders new summary card values from fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Avg Input Tokens: 2500 formatted
expect(screen.getByText("2,500")).toBeDefined();
// Avg Output Tokens: 1000 formatted
expect(screen.getByText("1,000")).toBeDefined();
// P50 cost: 50000 microdollars = $0.0500
expect(screen.getByText("$0.0500")).toBeDefined();
});
it("renders user table avg cost column with fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// User table should show Avg Cost / Req header
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
// Input/Output token columns
expect(screen.getByText("Input Tokens")).toBeDefined();
expect(screen.getByText("Output Tokens")).toBeDefined();
});
it("renders filter inputs", async () => {
@@ -246,6 +352,95 @@ describe("PlatformCostContent", () => {
expect(screen.getByText("Apply")).toBeDefined();
});
it("renders execution ID filter input", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Execution ID")).toBeDefined();
expect(screen.getByPlaceholderText("Filter by execution")).toBeDefined();
});
it("pre-fills execution ID filter from searchParams", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ graph_exec_id: "exec-123" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
expect(input.value).toBe("exec-123");
});
it("clears execution ID input on Clear click", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ graph_exec_id: "exec-123" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
fireEvent.click(screen.getByText("Clear"));
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
expect(input.value).toBe("");
});
it("passes execution ID to filter on Apply click", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
fireEvent.change(input, { target: { value: "exec-abc" } });
expect(input.value).toBe("exec-abc");
fireEvent.click(screen.getByText("Apply"));
// After apply, the input still holds the typed value
expect(input.value).toBe("exec-abc");
});
it("copies execution ID to clipboard on cell click in logs tab", async () => {
const writeText = vi.fn().mockResolvedValue(undefined);
vi.stubGlobal("navigator", { ...navigator, clipboard: { writeText } });
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// The exec ID cell shows at most the first 8 chars; "gx-123" is shorter, so it renders whole
const execIdCell = screen.getByText("gx-123".slice(0, 8));
fireEvent.click(execIdCell);
expect(writeText).toHaveBeenCalledWith("gx-123");
vi.unstubAllGlobals();
});
it("renders by-user tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,

View File

@@ -67,7 +67,10 @@ function LogsTable({
Cost
</th>
<th scope="col" className="px-3 py-3 text-right">
Tokens
In / Out
</th>
<th scope="col" className="px-3 py-3 text-right">
Cache (R/W)
</th>
<th scope="col" className="px-3 py-3 text-right">
Duration
@@ -105,12 +108,34 @@ function LogsTable({
? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.cache_read_tokens || log.cache_creation_tokens
? `${formatTokens(Number(log.cache_read_tokens ?? 0))} / ${formatTokens(Number(log.cache_creation_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.duration != null
? formatDuration(Number(log.duration))
: "-"}
</td>
<td className="px-3 py-2 text-xs text-muted-foreground">
<td
className={[
"px-3 py-2 text-xs text-muted-foreground",
log.graph_exec_id ? "cursor-pointer" : "",
].join(" ")}
title={
log.graph_exec_id ? String(log.graph_exec_id) : undefined
}
onClick={
log.graph_exec_id
? () => {
navigator.clipboard
.writeText(String(log.graph_exec_id))
.catch(() => {});
}
: undefined
}
>
{log.graph_exec_id
? String(log.graph_exec_id).slice(0, 8)
: "-"}
@@ -120,7 +145,7 @@ function LogsTable({
{logs.length === 0 && (
<tr>
<td
colSpan={10}
colSpan={11}
className="px-4 py-8 text-center text-muted-foreground"
>
No logs found

View File

@@ -2,12 +2,13 @@
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { formatMicrodollars, formatTokens } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";
import type { CostBucket } from "@/app/api/__generated__/models/costBucket";
interface Props {
searchParams: {
@@ -18,6 +19,7 @@ interface Props {
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
};
@@ -46,6 +48,8 @@ export function PlatformCostContent({ searchParams }: Props) {
setBlockInput,
typeInput,
setTypeInput,
executionIDInput,
setExecutionIDInput,
rateOverrides,
handleRateOverride,
updateUrl,
@@ -54,6 +58,76 @@ export function PlatformCostContent({ searchParams }: Props) {
handleExport,
} = usePlatformCostContent(searchParams);
const summaryCards: { label: string; value: string; subtitle?: string }[] =
dashboard
? [
{
label: "Known Cost",
value: formatMicrodollars(dashboard.total_cost_microdollars),
subtitle: "From providers that report USD cost",
},
{
label: "Estimated Total",
value: formatMicrodollars(totalEstimatedCost),
subtitle: "Including per-run cost estimates",
},
{
label: "Total Requests",
value: dashboard.total_requests.toLocaleString(),
},
{
label: "Active Users",
value: dashboard.total_users.toLocaleString(),
},
{
label: "Avg Cost / Request",
value: formatMicrodollars(
dashboard.avg_cost_microdollars_per_request ?? 0,
),
subtitle: "Known cost divided by cost-bearing requests",
},
{
label: "Avg Input Tokens",
value: Math.round(
dashboard.avg_input_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Prompt tokens per request (context size)",
},
{
label: "Avg Output Tokens",
value: Math.round(
dashboard.avg_output_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Completion tokens per request (response length)",
},
{
label: "Total Tokens",
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
subtitle: "Prompt vs completion token split",
},
{
label: "Typical Cost (P50)",
value: formatMicrodollars(dashboard.cost_p50_microdollars ?? 0),
subtitle: "Median cost per request",
},
{
label: "Upper Cost (P75)",
value: formatMicrodollars(dashboard.cost_p75_microdollars ?? 0),
subtitle: "75th percentile cost",
},
{
label: "High Cost (P95)",
value: formatMicrodollars(dashboard.cost_p95_microdollars ?? 0),
subtitle: "95th percentile cost",
},
{
label: "Peak Cost (P99)",
value: formatMicrodollars(dashboard.cost_p99_microdollars ?? 0),
subtitle: "99th percentile cost",
},
]
: [];
return (
<div className="flex flex-col gap-6">
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
@@ -164,6 +238,22 @@ export function PlatformCostContent({ searchParams }: Props) {
onChange={(e) => setTypeInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="execution-id-filter"
className="text-sm text-muted-foreground"
>
Execution ID
</label>
<input
id="execution-id-filter"
type="text"
placeholder="Filter by execution"
className="rounded border px-3 py-1.5 text-sm"
value={executionIDInput}
onChange={(e) => setExecutionIDInput(e.target.value)}
/>
</div>
<button
onClick={handleFilter}
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
@@ -179,6 +269,7 @@ export function PlatformCostContent({ searchParams }: Props) {
setModelInput("");
setBlockInput("");
setTypeInput("");
setExecutionIDInput("");
updateUrl({
start: "",
end: "",
@@ -187,6 +278,7 @@ export function PlatformCostContent({ searchParams }: Props) {
model: "",
block_name: "",
tracking_type: "",
graph_exec_id: "",
page: "1",
});
}}
@@ -204,37 +296,54 @@ export function PlatformCostContent({ searchParams }: Props) {
{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{/* 12 skeleton placeholders — one per summary card */}
{Array.from({ length: 12 }, (_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
<Skeleton className="h-32 rounded-lg" />
<Skeleton className="h-8 w-48 rounded" />
<Skeleton className="h-64 rounded-lg" />
</div>
) : (
<>
{dashboard && (
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
</div>
<>
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{summaryCards.map((card) => (
<SummaryCard
key={card.label}
label={card.label}
value={card.value}
subtitle={card.subtitle}
/>
))}
</div>
{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
<div className="rounded-lg border p-4">
<h3 className="mb-3 text-sm font-medium">
Cost Distribution by Bucket
</h3>
<div className="grid grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-6">
{dashboard.cost_buckets.map((b: CostBucket) => (
<div
key={b.bucket}
className="flex flex-col items-center rounded border p-2 text-center"
>
<span className="text-xs text-muted-foreground">
{b.bucket}
</span>
<span className="text-lg font-semibold">
{b.count.toLocaleString()}
</span>
</div>
))}
</div>
</div>
)}
</>
)}
<div

View File

@@ -3,6 +3,7 @@ import {
defaultRateFor,
estimateCostForRow,
formatMicrodollars,
formatTokens,
rateKey,
rateUnitLabel,
trackingValue,
@@ -33,6 +34,20 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<th scope="col" className="px-4 py-3 text-right">
Usage
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Input Tokens
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
@@ -74,6 +89,16 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<TrackingBadge trackingType={row.tracking_type} />
</td>
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
<td className="px-4 py-3 text-right">
{row.total_input_tokens > 0
? formatTokens(row.total_input_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.total_output_tokens > 0
? formatTokens(row.total_output_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
@@ -124,7 +149,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
{data.length === 0 && (
<tr>
<td
colSpan={8}
colSpan={10}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

View File

@@ -26,6 +26,9 @@ function UserTable({ data }: Props) {
<th scope="col" className="px-4 py-3 text-right">
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Avg Cost / Req
</th>
</tr>
</thead>
<tbody>
@@ -54,12 +57,21 @@ function UserTable({ data }: Props) {
<td className="px-4 py-3 text-right">
{formatTokens(row.total_output_tokens)}
</td>
<td className="px-4 py-3 text-right">
{(row.cost_bearing_request_count ?? 0) > 0 &&
row.total_cost_microdollars > 0
? formatMicrodollars(
row.total_cost_microdollars /
(row.cost_bearing_request_count ?? 1),
)
: "-"}
</td>
</tr>
))}
{data.length === 0 && (
<tr>
<td
colSpan={5}
colSpan={6}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

View File

@@ -23,6 +23,7 @@ interface InitialSearchParams {
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
}
@@ -43,6 +44,8 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
urlParams.get("block_name") || searchParams.block_name || "";
const typeFilter =
urlParams.get("tracking_type") || searchParams.tracking_type || "";
const executionIDFilter =
urlParams.get("graph_exec_id") || searchParams.graph_exec_id || "";
const [startInput, setStartInput] = useState(toLocalInput(startDate));
const [endInput, setEndInput] = useState(toLocalInput(endDate));
@@ -51,6 +54,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
const [modelInput, setModelInput] = useState(modelFilter);
const [blockInput, setBlockInput] = useState(blockFilter);
const [typeInput, setTypeInput] = useState(typeFilter);
const [executionIDInput, setExecutionIDInput] = useState(executionIDFilter);
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
{},
);
@@ -67,6 +71,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
model: modelFilter || undefined,
block_name: blockFilter || undefined,
tracking_type: typeFilter || undefined,
graph_exec_id: executionIDFilter || undefined,
};
const {
@@ -115,6 +120,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
model: modelInput,
block_name: blockInput,
tracking_type: typeInput,
graph_exec_id: executionIDInput,
page: "1",
});
}
@@ -185,6 +191,8 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
setBlockInput,
typeInput,
setTypeInput,
executionIDInput,
setExecutionIDInput,
rateOverrides,
handleRateOverride,
updateUrl,

View File

@@ -7,6 +7,10 @@ type SearchParams = {
end?: string;
provider?: string;
user_id?: string;
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
};

View File

@@ -113,8 +113,8 @@ export function CopilotPage() {
// Rate limit reset
rateLimitMessage,
dismissRateLimit,
// Dry run dev toggle
isDryRun,
// Dry run session state
sessionDryRun,
} = useCopilotPage();
const {
@@ -176,10 +176,15 @@ export function CopilotPage() {
>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<NotificationBanner />
{isDryRun && (
{/* Test mode banner: only shown when the CURRENT session is confirmed to be
a dry_run session via its immutable metadata. Never shown based on the
global isDryRun store preference alone — that only predicts future sessions
and would mislead users browsing non-dry-run sessions while the toggle is on.
The DryRunToggleButton (visible on new chats) already communicates the preference. */}
{sessionId && sessionDryRun && (
<div className="flex items-center justify-center gap-1.5 bg-amber-50 px-3 py-1.5 text-xs font-medium text-amber-800">
<Flask size={13} weight="bold" />
Test mode: new sessions use dry_run=true
Test mode: this session runs agents as a simulation
</div>
)}
{/* Drop overlay */}

View File

@@ -0,0 +1,168 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { CopilotPage } from "../CopilotPage";
// Mock child components that are complex and not under test here
vi.mock("../components/ChatContainer/ChatContainer", () => ({
ChatContainer: () => <div data-testid="chat-container" />,
}));
vi.mock("../components/ChatSidebar/ChatSidebar", () => ({
ChatSidebar: () => <div data-testid="chat-sidebar" />,
}));
vi.mock("../components/DeleteChatDialog/DeleteChatDialog", () => ({
DeleteChatDialog: () => null,
}));
vi.mock("../components/MobileDrawer/MobileDrawer", () => ({
MobileDrawer: () => null,
}));
vi.mock("../components/MobileHeader/MobileHeader", () => ({
MobileHeader: () => null,
}));
vi.mock("../components/NotificationBanner/NotificationBanner", () => ({
NotificationBanner: () => null,
}));
vi.mock("../components/NotificationDialog/NotificationDialog", () => ({
NotificationDialog: () => null,
}));
vi.mock("../components/RateLimitResetDialog/RateLimitResetDialog", () => ({
RateLimitResetDialog: () => null,
}));
vi.mock("../components/ScaleLoader/ScaleLoader", () => ({
ScaleLoader: () => <div data-testid="scale-loader" />,
}));
vi.mock("../components/ArtifactPanel/ArtifactPanel", () => ({
ArtifactPanel: () => null,
}));
vi.mock("@/components/ui/sidebar", () => ({
SidebarProvider: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
// Mock hooks that hit the network
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetCopilotUsage: () => ({
data: undefined,
isSuccess: false,
isError: false,
}),
}));
vi.mock("@/hooks/useCredits", () => ({
default: () => ({ credits: null, fetchCredits: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: {
ENABLE_PLATFORM_PAYMENT: "ENABLE_PLATFORM_PAYMENT",
ARTIFACTS: "ARTIFACTS",
CHAT_MODE_OPTION: "CHAT_MODE_OPTION",
},
useGetFlag: () => false,
}));
// Build the base mock return value for useCopilotPage
const basePageState = {
sessionId: null as string | null,
messages: [],
status: "ready" as const,
error: undefined,
stop: vi.fn(),
isReconnecting: false,
isSyncing: false,
createSession: vi.fn(),
onSend: vi.fn(),
isLoadingSession: false,
isSessionError: false,
isCreatingSession: false,
isUploadingFiles: false,
isUserLoading: false,
isLoggedIn: true,
hasMoreMessages: false,
isLoadingMore: false,
loadMore: vi.fn(),
isMobile: false,
isDrawerOpen: false,
sessions: [],
isLoadingSessions: false,
handleOpenDrawer: vi.fn(),
handleCloseDrawer: vi.fn(),
handleDrawerOpenChange: vi.fn(),
handleSelectSession: vi.fn(),
handleNewChat: vi.fn(),
sessionToDelete: null,
isDeleting: false,
handleConfirmDelete: vi.fn(),
handleCancelDelete: vi.fn(),
historicalDurations: {},
rateLimitMessage: null,
dismissRateLimit: vi.fn(),
isDryRun: false,
sessionDryRun: false,
};
const mockUseCopilotPage = vi.fn(() => basePageState);
vi.mock("../useCopilotPage", () => ({
useCopilotPage: () => mockUseCopilotPage(),
}));
afterEach(() => {
cleanup();
mockUseCopilotPage.mockReset();
mockUseCopilotPage.mockImplementation(() => basePageState);
});
describe("CopilotPage test-mode banner", () => {
it("does not show test-mode banner when there is no active session", () => {
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});
it("does not show test-mode banner when session exists but sessionDryRun is false", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: "session-abc",
sessionDryRun: false,
});
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});
it("shows test-mode banner when session exists and sessionDryRun is true", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: "session-abc",
sessionDryRun: true,
});
render(<CopilotPage />);
expect(
screen.getByText(/test mode.*this session runs agents/i),
).toBeDefined();
});
it("does not show test-mode banner when sessionDryRun is true but no sessionId", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: null,
sessionDryRun: true,
});
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});
it("shows loading spinner when user is loading", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
isUserLoading: true,
isLoggedIn: false,
});
render(<CopilotPage />);
expect(screen.getByTestId("scale-loader")).toBeDefined();
expect(screen.queryByTestId("chat-container")).toBeNull();
});
});

View File

@@ -1,6 +1,11 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { IMPERSONATION_HEADER_NAME } from "@/lib/constants";
import { getCopilotAuthHeaders } from "../helpers";
import {
getCopilotAuthHeaders,
getSendSuppressionReason,
resolveSessionDryRun,
} from "../helpers";
import type { UIMessage } from "ai";
vi.mock("@/lib/supabase/actions", () => ({
getWebSocketToken: vi.fn(),
@@ -16,6 +21,42 @@ import { getSystemHeaders } from "@/lib/impersonation";
const mockGetWebSocketToken = vi.mocked(getWebSocketToken);
const mockGetSystemHeaders = vi.mocked(getSystemHeaders);
describe("resolveSessionDryRun", () => {
it("returns false when queryData is null", () => {
expect(resolveSessionDryRun(null)).toBe(false);
});
it("returns false when queryData is undefined", () => {
expect(resolveSessionDryRun(undefined)).toBe(false);
});
it("returns false when status is not 200", () => {
expect(resolveSessionDryRun({ status: 404 })).toBe(false);
});
it("returns false when status is 200 but metadata.dry_run is false", () => {
expect(
resolveSessionDryRun({
status: 200,
data: { metadata: { dry_run: false } },
}),
).toBe(false);
});
it("returns false when status is 200 but metadata is missing", () => {
expect(resolveSessionDryRun({ status: 200, data: {} })).toBe(false);
});
it("returns true when status is 200 and metadata.dry_run is true", () => {
expect(
resolveSessionDryRun({
status: 200,
data: { metadata: { dry_run: true } },
}),
).toBe(true);
});
});
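// For reference, a minimal sketch consistent with the cases above (the real
// implementation is imported from ../helpers, so this stays commented out;
// the shape of queryData is assumed from the fixtures in these tests):
//
//   function resolveSessionDryRunSketch(queryData?: {
//     status?: number;
//     data?: { metadata?: { dry_run?: boolean } };
//   } | null): boolean {
//     return queryData?.status === 200 && queryData.data?.metadata?.dry_run === true;
//   }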
describe("getCopilotAuthHeaders", () => {
beforeEach(() => {
vi.clearAllMocks();
@@ -72,3 +113,71 @@ describe("getCopilotAuthHeaders", () => {
);
});
});
// ─── getSendSuppressionReason ─────────────────────────────────────────────────
function makeUserMsg(text: string): UIMessage {
return {
id: "msg-1",
role: "user",
content: text,
parts: [{ type: "text", text }],
} as UIMessage;
}
describe("getSendSuppressionReason", () => {
it("returns null when no dedup context exists (fresh ref)", () => {
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: null,
messages: [],
});
expect(result).toBeNull();
});
it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => {
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: true,
lastSubmittedText: null,
messages: [],
});
expect(result).toBe("reconnecting");
});
it("returns 'duplicate' when same text was submitted and is the last user message", () => {
// This is the core regression test: after a successful turn the ref
// is intentionally NOT cleared to null, so submitting the same text
// again is caught here.
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages: [makeUserMsg("hello")],
});
expect(result).toBe("duplicate");
});
it("returns null when same ref text but different last user message (different question)", () => {
// User asked "hello" before, got a reply, then asked a different question
// — the last user message in chat is now different, so no suppression.
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages: [makeUserMsg("hello"), makeUserMsg("something else")],
});
expect(result).toBeNull();
});
it("returns null when text differs from lastSubmittedText", () => {
const result = getSendSuppressionReason({
text: "new question",
isReconnectScheduled: false,
lastSubmittedText: "old question",
messages: [makeUserMsg("old question")],
});
expect(result).toBeNull();
});
});
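// A minimal sketch matching the behaviors tested above (illustrative only; the
// real helper is imported from ../helpers): reconnect wins, then a submission
// is a duplicate only if it matches both the ref and the last user message.
//
//   function getSendSuppressionReasonSketch(args: {
//     text: string;
//     isReconnectScheduled: boolean;
//     lastSubmittedText: string | null;
//     messages: UIMessage[];
//   }): "reconnecting" | "duplicate" | null {
//     if (args.isReconnectScheduled) return "reconnecting";
//     const lastUser = [...args.messages].reverse().find((m) => m.role === "user");
//     const lastText = lastUser?.parts.find((p) => p.type === "text")?.text;
//     const isDuplicate =
//       args.lastSubmittedText !== null &&
//       args.text === args.lastSubmittedText &&
//       lastText === args.text;
//     return isDuplicate ? "duplicate" : null;
//   }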

View File

@@ -1,4 +1,4 @@
import { describe, expect, it, beforeEach, vi } from "vitest";
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
import { useCopilotUIStore } from "../store";
vi.mock("@sentry/nextjs", () => ({
@@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => {
isNotificationsEnabled: false,
isSoundEnabled: true,
showNotificationDialog: false,
copilotMode: "extended_thinking",
copilotChatMode: "extended_thinking",
copilotLlmModel: "standard",
});
});
@@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => {
});
});
describe("copilotMode", () => {
describe("copilotChatMode", () => {
it("defaults to extended_thinking", () => {
expect(useCopilotUIStore.getState().copilotMode).toBe(
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
"extended_thinking",
);
});
it("sets mode to fast", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
expect(useCopilotUIStore.getState().copilotMode).toBe("fast");
useCopilotUIStore.getState().setCopilotChatMode("fast");
expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast");
});
it("sets mode back to extended_thinking", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setCopilotMode("extended_thinking");
expect(useCopilotUIStore.getState().copilotMode).toBe(
useCopilotUIStore.getState().setCopilotChatMode("fast");
useCopilotUIStore.getState().setCopilotChatMode("extended_thinking");
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
"extended_thinking",
);
});
it("does not persist mode to localStorage", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
it("persists mode to localStorage", () => {
useCopilotUIStore.getState().setCopilotChatMode("fast");
expect(window.localStorage.getItem("copilot-mode")).toBe("fast");
});
});
describe("copilotLlmModel", () => {
it("defaults to standard", () => {
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("standard");
});
it("sets model to advanced", () => {
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("advanced");
});
it("persists model to localStorage", () => {
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
expect(window.localStorage.getItem("copilot-model")).toBe("advanced");
});
});
describe("clearCopilotLocalData", () => {
it("resets state and clears localStorage keys", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setCopilotChatMode("fast");
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
useCopilotUIStore.getState().setNotificationsEnabled(true);
useCopilotUIStore.getState().toggleSound();
useCopilotUIStore.getState().addCompletedSession("s1");
@@ -190,7 +208,8 @@ describe("useCopilotUIStore", () => {
useCopilotUIStore.getState().clearCopilotLocalData();
const state = useCopilotUIStore.getState();
expect(state.copilotMode).toBe("extended_thinking");
expect(state.copilotChatMode).toBe("extended_thinking");
expect(state.copilotLlmModel).toBe("standard");
expect(state.isNotificationsEnabled).toBe(false);
expect(state.isSoundEnabled).toBe(true);
expect(state.completedSessionIDs.size).toBe(0);
@@ -198,6 +217,8 @@ describe("useCopilotUIStore", () => {
window.localStorage.getItem("copilot-notifications-enabled"),
).toBeNull();
expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
expect(window.localStorage.getItem("copilot-model")).toBeNull();
expect(
window.localStorage.getItem("copilot-completed-sessions"),
).toBeNull();
@@ -222,3 +243,24 @@ describe("useCopilotUIStore", () => {
});
});
});
describe("useCopilotUIStore localStorage initialisation", () => {
afterEach(() => {
vi.resetModules();
window.localStorage.clear();
});
it("reads fast chat mode from localStorage on store creation", async () => {
window.localStorage.setItem("copilot-mode", "fast");
vi.resetModules();
const { useCopilotUIStore: fresh } = await import("../store");
expect(fresh.getState().copilotChatMode).toBe("fast");
});
it("reads advanced model from localStorage on store creation", async () => {
window.localStorage.setItem("copilot-model", "advanced");
vi.resetModules();
const { useCopilotUIStore: fresh } = await import("../store");
expect(fresh.getState().copilotLlmModel).toBe("advanced");
});
});
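// The localStorage round-trip these tests rely on, sketched (illustrative;
// keys and defaults are taken from the assertions above): read once at module
// load, persist on every set.
//
//   const initialChatMode =
//     (typeof window !== "undefined" &&
//       window.localStorage.getItem("copilot-mode")) ||
//     "extended_thinking"; // default when nothing is stored
//   // ...and each setter writes through:
//   //   window.localStorage.setItem("copilot-mode", mode);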

View File

@@ -0,0 +1,145 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { ArtifactCard } from "./ArtifactCard";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.html",
mimeType: "text/html",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
const meta: Meta<typeof ArtifactCard> = {
title: "Copilot/ArtifactCard",
component: ArtifactCard,
tags: ["autodocs"],
parameters: {
layout: "padded",
docs: {
description: {
component:
"Inline artifact card rendered in chat messages. Openable artifacts show a caret and open the ArtifactPanel on click. Download-only artifacts trigger a file download.",
},
},
},
decorators: [
(Story) => (
<div className="w-96">
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const OpenableHTML: Story = {
name: "Openable (HTML)",
args: {
artifact: makeArtifact({
title: "dashboard.html",
mimeType: "text/html",
}),
},
};
export const OpenableImage: Story = {
name: "Openable (Image)",
args: {
artifact: makeArtifact({
id: "img-card",
title: "chart.png",
mimeType: "image/png",
}),
},
};
export const OpenableCode: Story = {
name: "Openable (Code)",
args: {
artifact: makeArtifact({
title: "script.py",
mimeType: "text/x-python",
}),
},
};
export const DownloadOnly: Story = {
name: "Download Only (ZIP)",
args: {
artifact: makeArtifact({
title: "archive.zip",
mimeType: "application/zip",
sizeBytes: 2_500_000,
}),
},
};
export const PreviewableVideo: Story = {
name: "Previewable (Video)",
args: {
artifact: makeArtifact({
title: "demo.mp4",
mimeType: "video/mp4",
sizeBytes: 15_000_000,
}),
},
parameters: {
docs: {
description: {
story:
"Videos with supported formats (MP4, WebM, M4V) are previewable inline in the artifact panel.",
},
},
},
};
export const WithSize: Story = {
name: "With File Size",
args: {
artifact: makeArtifact({
title: "data.csv",
mimeType: "text/csv",
sizeBytes: 524_288,
}),
},
};
export const UserUpload: Story = {
name: "User Upload Origin",
args: {
artifact: makeArtifact({
title: "requirements.txt",
mimeType: "text/plain",
origin: "user-upload",
}),
},
};
export const ActiveState: Story = {
name: "Active (Panel Open)",
args: {
artifact: makeArtifact({ id: "active-card" }),
},
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: makeArtifact({ id: "active-card" }),
history: [],
},
});
return <Story />;
},
],
};

View File

@@ -0,0 +1,223 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactPanel } from "./ArtifactPanel";
import { useCopilotUIStore } from "../../store";
import type { ArtifactRef } from "../../store";
const PROXY_BASE = "/api/proxy/api/workspace/files";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/file-001/download`,
origin: "agent",
...overrides,
};
}
function openPanelWith(artifact: ArtifactRef) {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: artifact,
history: [],
},
});
}
const meta: Meta<typeof ArtifactPanel> = {
title: "Copilot/ArtifactPanel",
component: ArtifactPanel,
tags: ["autodocs"],
parameters: {
layout: "fullscreen",
docs: {
description: {
component:
"Side panel for previewing workspace artifacts. Supports resize, minimize, maximize, and navigation history. Bug: panel auto-opens on chat switch instead of staying collapsed.",
},
},
},
decorators: [
(Story) => (
<div className="flex h-[600px] w-full">
<div className="flex-1 bg-zinc-50 p-8">
<p className="text-sm text-zinc-500">Chat area</p>
</div>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const OpenWithTextArtifact: Story = {
name: "Open — Text File",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({ title: "notes.txt", mimeType: "text/plain" }),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/file-001/download`, () => {
return HttpResponse.text(
"These are some notes from the agent execution.\n\nKey findings:\n1. Performance improved by 23%\n2. Memory usage reduced\n3. Error rate dropped to 0.1%",
);
}),
],
},
},
};
export const OpenWithHTMLArtifact: Story = {
name: "Open — HTML",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "html-panel",
title: "dashboard.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/html-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/html-panel/download`, () => {
return HttpResponse.text(
`<!DOCTYPE html><html><body class="p-8 font-sans"><h1 class="text-2xl font-bold text-indigo-600">Dashboard</h1><p class="mt-2 text-gray-600">HTML artifact in the panel.</p></body></html>`,
);
}),
],
},
},
};
export const OpenWithImageArtifact: Story = {
name: "Open — Image (Bug: No Loading State)",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "img-panel",
title: "chart.png",
mimeType: "image/png",
sourceUrl: `${PROXY_BASE}/img-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-panel/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300"><rect width="500" height="300" fill="#dbeafe"/><text x="250" y="150" text-anchor="middle" fill="#1e40af" font-size="20">Image Preview (no skeleton)</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
docs: {
description: {
story:
"**BUG:** Image artifacts render with a bare `<img>` tag — no loading skeleton or error handling. Compare with text/HTML artifacts which show a proper skeleton while loading.",
},
},
},
};
export const MinimizedStrip: Story = {
name: "Minimized",
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: true,
isMaximized: false,
width: 600,
activeArtifact: makeArtifact(),
history: [],
},
});
return <Story />;
},
],
};
export const ErrorState: Story = {
name: "Error — Failed to Load (Stale Artifact)",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "stale-panel",
title: "old-report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/stale-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/stale-panel/download`, () => {
return new HttpResponse(null, { status: 404 });
}),
],
},
docs: {
description: {
story:
"Shows what users see when opening a previously generated artifact that no longer exists on the backend (404). The 'Try again' button retries the fetch.",
},
},
},
};
export const Closed: Story = {
name: "Closed (Default State)",
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: false,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: null,
history: [],
},
});
return <Story />;
},
],
parameters: {
docs: {
description: {
story:
"The default state — panel is closed. It should only open when a user clicks on an artifact card in the chat.",
},
},
},
};

View File

@@ -0,0 +1,413 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { downloadArtifact } from "../downloadArtifact";
import type { ArtifactRef } from "../../../store";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
describe("downloadArtifact", () => {
let clickSpy: ReturnType<typeof vi.fn>;
let removeSpy: ReturnType<typeof vi.fn>;
beforeEach(() => {
clickSpy = vi.fn();
removeSpy = vi.fn();
vi.stubGlobal(
"URL",
Object.assign(URL, {
createObjectURL: vi.fn().mockReturnValue("blob:fake-url"),
revokeObjectURL: vi.fn(),
}),
);
vi.spyOn(document, "createElement").mockReturnValue({
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
} as unknown as HTMLAnchorElement);
vi.spyOn(document.body, "appendChild").mockImplementation(
(node) => node as ChildNode,
);
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it("downloads file successfully on 200 response", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["pdf content"])),
}),
);
await downloadArtifact(makeArtifact());
expect(fetch).toHaveBeenCalledWith(
"/api/proxy/api/workspace/files/file-001/download",
);
expect(clickSpy).toHaveBeenCalled();
expect(removeSpy).toHaveBeenCalled();
expect(URL.revokeObjectURL).toHaveBeenCalledWith("blob:fake-url");
});
it("rejects on persistent server error after exhausting retries", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 500,
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 500",
);
expect(clickSpy).not.toHaveBeenCalled();
});
it("rejects on persistent network error after exhausting retries", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.reject(new Error("Network error"));
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Network error",
);
expect(callCount).toBe(3);
expect(clickSpy).not.toHaveBeenCalled();
});
it("retries on transient network error and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.reject(new Error("Connection reset"));
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("retries on transient 500 and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 500 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
// Should succeed on second attempt
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("sanitizes dangerous filenames", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "../../../etc/passwd" }));
expect(anchor.download).not.toContain("..");
expect(anchor.download).not.toContain("/");
});
// ── Transient retry codes ─────────────────────────────────────────
it("retries on 408 (Request Timeout) and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 408 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("retries on 429 (Too Many Requests) and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 429 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
// ── Non-transient errors ──────────────────────────────────────────
it("rejects immediately on 403 (non-transient) without retry", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 403 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 403",
);
expect(callCount).toBe(1);
expect(clickSpy).not.toHaveBeenCalled();
});
it("rejects immediately on 404 without retry", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 404 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 404",
);
expect(callCount).toBe(1);
});
// ── Exhausted retries ─────────────────────────────────────────────
it("rejects after exhausting all retries on persistent 500", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 500 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 500",
);
// Initial attempt + 2 retries = 3 total
expect(callCount).toBe(3);
expect(clickSpy).not.toHaveBeenCalled();
});
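// The retry policy exercised above, as a sketch (illustrative; the real logic
// lives in ../downloadArtifact): up to 3 total attempts, retrying network
// errors and 408/429/5xx statuses, failing fast on anything else.
//
//   async function fetchWithRetrySketch(url: string, attempts = 3): Promise<Response> {
//     let lastError: Error = new Error("Download failed");
//     for (let i = 0; i < attempts; i++) {
//       let res: Response;
//       try {
//         res = await fetch(url);
//       } catch (e) {
//         lastError = e as Error; // network error: transient, retry
//         continue;
//       }
//       if (res.ok) return res;
//       if (![408, 429].includes(res.status) && res.status < 500) {
//         throw new Error(`Download failed: ${res.status}`); // non-transient: no retry
//       }
//       lastError = new Error(`Download failed: ${res.status}`); // transient: retry
//     }
//     throw lastError;
//   }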
// ── Filename edge cases ───────────────────────────────────────────
it("falls back to 'download' when title is empty", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "" }));
expect(anchor.download).toBe("download");
});
it("falls back to 'download' when title is only dots", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
// Dot-only names should not produce a hidden or empty filename.
await downloadArtifact(makeArtifact({ title: "...." }));
expect(anchor.download).toBe("download");
});
it("replaces special chars with underscores (not empty)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: '***???"' }));
// Special chars become underscores, not removed
expect(anchor.download).toBe("_______");
});
it("strips leading dots from filename", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "...hidden.txt" }));
expect(anchor.download).not.toMatch(/^\./);
expect(anchor.download).toContain("hidden.txt");
});
it("replaces Windows-reserved characters", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(
makeArtifact({ title: "file<name>with:bad*chars?.txt" }),
);
expect(anchor.download).not.toMatch(/[<>:*?]/);
});
it("replaces control characters in filename", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(
makeArtifact({ title: "file\x00with\x1fcontrol.txt" }),
);
expect(anchor.download).not.toMatch(/[\x00-\x1f]/);
});
});

View File

@@ -0,0 +1,460 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactContent } from "./ArtifactContent";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import {
Code,
File,
FileHtml,
FileText,
Image,
Table,
} from "@phosphor-icons/react";
const PROXY_BASE = "/api/proxy/api/workspace/files";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "test.txt",
mimeType: "text/plain",
sourceUrl: `${PROXY_BASE}/file-001/download`,
origin: "agent",
...overrides,
};
}
function makeClassification(
overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
return {
type: "text",
icon: FileText,
label: "Text",
openable: true,
hasSourceToggle: false,
...overrides,
};
}
const meta: Meta<typeof ArtifactContent> = {
title: "Copilot/ArtifactContent",
component: ArtifactContent,
tags: ["autodocs"],
parameters: {
layout: "padded",
docs: {
description: {
component:
"Renders artifact content based on file type classification. Supports images, HTML, code, CSV, JSON, markdown, PDF, and plain text. Bug: image artifacts render as bare <img> with no loading/error states.",
},
},
},
decorators: [
(Story) => (
<div
className="flex h-[500px] w-[600px] flex-col overflow-hidden border border-zinc-200"
style={{ resize: "both" }}
>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const ImageArtifactPNG: Story = {
name: "Image (PNG) — No Loading Skeleton (Bug #1)",
args: {
artifact: makeArtifact({
id: "img-png",
title: "chart.png",
mimeType: "image/png",
sourceUrl: `${PROXY_BASE}/img-png/download`,
}),
isSourceView: false,
classification: makeClassification({ type: "image", icon: Image }),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-png/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#e0e7ff"/><text x="200" y="150" text-anchor="middle" fill="#4338ca" font-size="24">PNG Placeholder</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
docs: {
description: {
story:
"**BUG:** This renders a bare `<img>` tag with no loading skeleton or error handling. Compare with WorkspaceFileRenderer which has proper Skeleton + onError states.",
},
},
},
};
export const ImageArtifactSVG: Story = {
name: "Image (SVG)",
args: {
artifact: makeArtifact({
id: "img-svg",
title: "diagram.svg",
mimeType: "image/svg+xml",
sourceUrl: `${PROXY_BASE}/img-svg/download`,
}),
isSourceView: false,
classification: makeClassification({ type: "image", icon: Image }),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-svg/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#fef3c7"/><circle cx="200" cy="150" r="80" fill="#f59e0b"/><text x="200" y="155" text-anchor="middle" fill="white" font-size="20">SVG OK</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
},
};
export const HTMLArtifact: Story = {
name: "HTML",
args: {
artifact: makeArtifact({
id: "html-001",
title: "page.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/html-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/html-001/download`, () => {
return HttpResponse.text(
`<!DOCTYPE html>
<html>
<head><title>Artifact Preview</title></head>
<body class="p-8 font-sans">
<h1 class="text-2xl font-bold text-indigo-600 mb-4">HTML Artifact</h1>
<p class="text-gray-700">This is an HTML artifact rendered in a sandboxed iframe with Tailwind CSS injected.</p>
<div class="mt-4 p-4 bg-blue-50 rounded-lg border border-blue-200">
<p class="text-blue-800">Interactive content works via allow-scripts sandbox.</p>
</div>
</body>
</html>`,
{ headers: { "Content-Type": "text/html" } },
);
}),
],
},
},
};
export const CodeArtifact: Story = {
name: "Code (Python)",
args: {
artifact: makeArtifact({
id: "code-001",
title: "analysis.py",
mimeType: "text/x-python",
sourceUrl: `${PROXY_BASE}/code-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "code",
icon: Code,
label: "Code",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/code-001/download`, () => {
return HttpResponse.text(
`import pandas as pd
import matplotlib.pyplot as plt
def analyze_data(filepath: str) -> pd.DataFrame:
"""Load and analyze CSV data."""
df = pd.read_csv(filepath)
summary = df.describe()
print(f"Loaded {len(df)} rows")
return summary
if __name__ == "__main__":
result = analyze_data("data.csv")
print(result)`,
{ headers: { "Content-Type": "text/plain" } },
);
}),
],
},
},
};
export const CSVArtifact: Story = {
name: "CSV (Spreadsheet)",
args: {
artifact: makeArtifact({
id: "csv-001",
title: "data.csv",
mimeType: "text/csv",
sourceUrl: `${PROXY_BASE}/csv-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "csv",
icon: Table,
label: "Spreadsheet",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/csv-001/download`, () => {
return HttpResponse.text(
`Name,Age,City,Score
Alice,28,New York,92
Bob,35,San Francisco,87
Charlie,22,Chicago,95
Diana,31,Boston,88
Eve,27,Seattle,91`,
{ headers: { "Content-Type": "text/csv" } },
);
}),
],
},
},
};
export const JSONArtifact: Story = {
name: "JSON (Data)",
args: {
artifact: makeArtifact({
id: "json-001",
title: "config.json",
mimeType: "application/json",
sourceUrl: `${PROXY_BASE}/json-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "json",
icon: Code,
label: "Data",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/json-001/download`, () => {
return HttpResponse.text(
JSON.stringify(
{
name: "AutoGPT Agent",
version: "2.0",
capabilities: ["web_search", "code_execution", "file_io"],
settings: { maxTokens: 4096, temperature: 0.7 },
},
null,
2,
),
{ headers: { "Content-Type": "application/json" } },
);
}),
],
},
},
};
export const MarkdownArtifact: Story = {
name: "Markdown",
args: {
artifact: makeArtifact({
id: "md-001",
title: "README.md",
mimeType: "text/markdown",
sourceUrl: `${PROXY_BASE}/md-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "markdown",
icon: FileText,
label: "Document",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/md-001/download`, () => {
return HttpResponse.text(
`# Project Summary
## Overview
This is a **markdown** artifact rendered through the global renderer registry.
## Features
- Headings and paragraphs
- **Bold** and *italic* text
- Lists and code blocks
\`\`\`python
print("Hello from markdown!")
\`\`\`
> Blockquotes are also supported.`,
{ headers: { "Content-Type": "text/plain" } },
);
}),
],
},
},
};
export const PDFArtifact: Story = {
name: "PDF",
args: {
artifact: makeArtifact({
id: "pdf-001",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: `${PROXY_BASE}/pdf-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "pdf",
icon: FileText,
label: "PDF",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/pdf-001/download`, () => {
return HttpResponse.arrayBuffer(new ArrayBuffer(100), {
headers: { "Content-Type": "application/pdf" },
});
}),
],
},
docs: {
description: {
story:
"PDF artifacts are rendered in an unsandboxed iframe using a blob URL (Chromium bug #413851 prevents sandboxed PDF rendering).",
},
},
},
};
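// The flow this story exercises (sketch of the hook's PDF path): the file is
// fetched, turned into a Blob, handed to URL.createObjectURL, and the
// resulting blob: URL becomes the iframe src (revoked again on cleanup).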
export const ErrorState: Story = {
name: "Error — Failed to Load Content",
args: {
artifact: makeArtifact({
id: "error-001",
title: "old-report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/error-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/error-001/download`, () => {
return new HttpResponse(null, { status: 404 });
}),
],
},
docs: {
description: {
story:
"Shows the error state when an artifact fails to load (e.g., old/expired file returning 404). Includes a 'Try again' retry button.",
},
},
},
};
export const LoadingSkeleton: Story = {
name: "Loading State",
args: {
artifact: makeArtifact({
id: "loading-001",
title: "loading.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/loading-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/loading-001/download`, async () => {
// Delay response to show loading state
await new Promise((r) => setTimeout(r, 999999));
return HttpResponse.text("never resolves");
}),
],
},
docs: {
description: {
story:
"Shows the skeleton loading state while content is being fetched.",
},
},
},
};
export const DownloadOnly: Story = {
name: "Download Only (Binary)",
args: {
artifact: makeArtifact({
id: "bin-001",
title: "archive.zip",
mimeType: "application/zip",
sourceUrl: `${PROXY_BASE}/bin-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "download-only",
icon: File,
label: "File",
openable: false,
}),
},
parameters: {
docs: {
description: {
story:
"Download-only files (binary, video, etc.) are not rendered inline. The ArtifactPanel shows nothing for these — they are handled by ArtifactCard with a download button.",
},
},
},
};

View File

@@ -2,7 +2,8 @@
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
-import { Suspense } from "react";
+import { Suspense, useState } from "react";
+import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
@@ -63,6 +64,90 @@ function ArtifactContentLoader({
);
}
function ArtifactImage({ src, alt }: { src: string; alt: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
if (error) {
return (
<div
role="alert"
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm text-zinc-500">Failed to load image</p>
<button
type="button"
onClick={() => {
setError(false);
setLoaded(false);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Try again
</button>
</div>
);
}
return (
<div className="relative flex items-center justify-center p-4">
{!loaded && (
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
)}
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
alt={alt}
className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
onLoad={() => setLoaded(true)}
onError={() => setError(true)}
/>
</div>
);
}
function ArtifactVideo({ src }: { src: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
if (error) {
return (
<div
role="alert"
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm text-zinc-500">Failed to load video</p>
<button
type="button"
onClick={() => {
setError(false);
setLoaded(false);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Try again
</button>
</div>
);
}
return (
<div className="relative flex items-center justify-center p-4">
{!loaded && (
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
)}
<video
src={src}
controls
preload="metadata"
className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
onLoadedMetadata={() => setLoaded(true)}
onError={() => setError(true)}
/>
</div>
);
}
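// Both wrappers are keyed by artifact.sourceUrl at their call sites below, so
// switching artifacts remounts them and resets the loaded/error state.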
function ArtifactRenderer({
artifact,
content,
@@ -79,17 +164,19 @@ function ArtifactRenderer({
// Image: render directly from URL (no content fetch)
if (classification.type === "image") {
return (
<div className="flex items-center justify-center p-4">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={artifact.sourceUrl}
alt={artifact.title}
className="max-h-full max-w-full object-contain"
/>
</div>
<ArtifactImage
key={artifact.sourceUrl}
src={artifact.sourceUrl}
alt={artifact.title}
/>
);
}
+// Video: render with <video> controls (no content fetch)
+if (classification.type === "video") {
+return <ArtifactVideo key={artifact.sourceUrl} src={artifact.sourceUrl} />;
+}
if (classification.type === "pdf" && pdfUrl) {
// No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
// (Chromium bug #413851). The blob URL has a null origin so it can't
@@ -164,7 +251,16 @@ function ArtifactRenderer({
// CSV: pass with explicit metadata so CSVRenderer matches
if (classification.type === "csv") {
-const csvMeta = { mimeType: "text/csv", filename: artifact.title };
+const normalizedMime = artifact.mimeType
+?.toLowerCase()
+.split(";")[0]
+?.trim();
+const csvMimeType =
+normalizedMime === "text/tab-separated-values" ||
+artifact.title.toLowerCase().endsWith(".tsv")
+? "text/tab-separated-values"
+: "text/csv";
+const csvMeta = { mimeType: csvMimeType, filename: artifact.title };
const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
if (csvRenderer) {
return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;

View File

@@ -0,0 +1,67 @@
import { render, screen, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import {
buildReactArtifactSrcDoc,
collectPreviewStyles,
transpileReactArtifactSource,
} from "./reactArtifactPreview";
vi.mock("./reactArtifactPreview", () => ({
buildReactArtifactSrcDoc: vi.fn(),
collectPreviewStyles: vi.fn(),
transpileReactArtifactSource: vi.fn(),
}));
describe("ArtifactReactPreview", () => {
beforeEach(() => {
vi.mocked(collectPreviewStyles).mockReturnValue("<style>preview</style>");
vi.mocked(buildReactArtifactSrcDoc).mockReturnValue("<html>preview</html>");
});
it("renders an iframe preview after transpilation succeeds", async () => {
vi.mocked(transpileReactArtifactSource).mockResolvedValue(
"module.exports.default = function Artifact() { return null; };",
);
const { container } = render(
<ArtifactReactPreview
source="export default function Artifact() { return null; }"
title="Artifact.tsx"
/>,
);
await waitFor(() => {
expect(buildReactArtifactSrcDoc).toHaveBeenCalledWith(
"module.exports.default = function Artifact() { return null; };",
"Artifact.tsx",
"<style>preview</style>",
);
});
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
expect(iframe?.getAttribute("title")).toBe("Artifact.tsx preview");
expect(iframe?.getAttribute("srcdoc")).toBe("<html>preview</html>");
});
it("shows a readable error when transpilation fails", async () => {
vi.mocked(transpileReactArtifactSource).mockRejectedValue(
new Error("Transpile exploded"),
);
render(
<ArtifactReactPreview
source="export default function Artifact() {"
title="Broken.tsx"
/>,
);
await waitFor(() => {
expect(screen.getByText("Failed to render React preview")).toBeTruthy();
});
expect(screen.getByText("Transpile exploded")).toBeTruthy();
});
});

View File

@@ -0,0 +1,970 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import {
cleanup,
fireEvent,
render,
screen,
waitFor,
} from "@testing-library/react";
import { ArtifactContent } from "../ArtifactContent";
import type { ArtifactRef } from "../../../../store";
import { classifyArtifact, type ArtifactClassification } from "../../helpers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { ArtifactReactPreview } from "../ArtifactReactPreview";
// Mock the renderers so we don't pull in the full renderer dependency tree
vi.mock("@/components/contextual/OutputRenderers", () => ({
globalRegistry: {
getRenderer: vi.fn().mockReturnValue({
render: vi.fn((_val: unknown, meta: Record<string, unknown>) => (
<div data-testid="global-renderer">
rendered:{String(meta?.mimeType ?? "unknown")}
</div>
)),
}),
},
}));
vi.mock(
"@/components/contextual/OutputRenderers/renderers/CodeRenderer",
() => ({
codeRenderer: {
render: vi.fn((content: string) => (
<div data-testid="code-renderer">{content}</div>
)),
},
}),
);
vi.mock("../ArtifactReactPreview", () => ({
ArtifactReactPreview: vi.fn(
({ source, title }: { source: string; title: string }) => (
<div data-testid="react-preview" data-title={title}>
{source}
</div>
),
),
}));
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "test.txt",
mimeType: "text/plain",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
function makeClassification(
overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
return {
type: "text",
icon: vi.fn(() => null) as unknown as ArtifactClassification["icon"],
label: "Text",
openable: true,
hasSourceToggle: false,
...overrides,
};
}
describe("ArtifactContent", () => {
beforeEach(() => {
vi.clearAllMocks();
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("file content here"),
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
});
afterEach(() => {
cleanup();
vi.unstubAllGlobals();
});
// ── Image ─────────────────────────────────────────────────────────
it("renders image artifact as img tag with loading skeleton", () => {
const artifact = makeArtifact({
id: "img-001",
title: "photo.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-001/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
expect(img).toBeTruthy();
expect(img?.getAttribute("src")).toBe(
"/api/proxy/api/workspace/files/img-001/download",
);
expect(fetch).not.toHaveBeenCalled();
});
it("image artifact shows loading skeleton before image loads", () => {
const artifact = makeArtifact({
id: "img-skeleton",
title: "photo.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-skeleton/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
// Skeleton uses animate-pulse class
const skeleton = container.querySelector('[class*="animate-pulse"]');
expect(skeleton).toBeTruthy();
});
it("image artifact shows error state when image fails to load", () => {
const artifact = makeArtifact({
id: "img-error",
title: "broken.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-error/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
expect(img).toBeTruthy();
fireEvent.error(img!);
const errorAlert = screen.queryByRole("alert");
expect(errorAlert).toBeTruthy();
expect(screen.queryByText("Failed to load image")).toBeTruthy();
});
it("image retry resets error and re-shows img", async () => {
const artifact = makeArtifact({
id: "img-retry",
title: "retry.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-retry/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
fireEvent.error(img!);
// Should show error state
await waitFor(() => {
expect(screen.queryByText("Failed to load image")).toBeTruthy();
});
// Click "Try again"
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
// Error should be cleared, img should reappear
await waitFor(() => {
expect(screen.queryByText("Failed to load image")).toBeNull();
expect(container.querySelector("img")).toBeTruthy();
});
});
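// "Try again" works because the error branch unmounts the <img>; clearing the
// flags remounts it, which issues a fresh request (subject to browser caching).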
// ── Video ─────────────────────────────────────────────────────────
it("renders video artifact with video tag and controls", () => {
const artifact = makeArtifact({
id: "vid-001",
title: "clip.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-001/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
expect(video).toBeTruthy();
expect(video?.hasAttribute("controls")).toBe(true);
expect(video?.getAttribute("src")).toBe(
"/api/proxy/api/workspace/files/vid-001/download",
);
expect(fetch).not.toHaveBeenCalled();
});
it("video shows loading skeleton before metadata loads", () => {
const artifact = makeArtifact({
id: "vid-skel",
title: "clip.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-skel/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const skeleton = container.querySelector('[class*="animate-pulse"]');
expect(skeleton).toBeTruthy();
// After metadata loads, skeleton should disappear
const video = container.querySelector("video");
fireEvent.loadedMetadata(video!);
expect(container.querySelector('[class*="animate-pulse"]')).toBeNull();
});
it("video shows error state when video fails to load", () => {
const artifact = makeArtifact({
id: "vid-error",
title: "broken.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-error/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
expect(video).toBeTruthy();
fireEvent.error(video!);
const errorAlert = screen.queryByRole("alert");
expect(errorAlert).toBeTruthy();
expect(screen.queryByText("Failed to load video")).toBeTruthy();
});
it("video retry resets error and re-shows video", async () => {
const artifact = makeArtifact({
id: "vid-retry",
title: "retry.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-retry/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
fireEvent.error(video!);
await waitFor(() => {
expect(screen.queryByText("Failed to load video")).toBeTruthy();
});
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
await waitFor(() => {
expect(screen.queryByText("Failed to load video")).toBeNull();
expect(container.querySelector("video")).toBeTruthy();
});
});
// ── PDF ───────────────────────────────────────────────────────────
it("renders PDF artifact in unsandboxed iframe with blob URL", async () => {
const blobUrl = "blob:http://localhost/fake-pdf-blob";
vi.stubGlobal(
"URL",
Object.assign(URL, {
createObjectURL: vi.fn().mockReturnValue(blobUrl),
revokeObjectURL: vi.fn(),
}),
);
const artifact = makeArtifact({
id: "pdf-render",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: "/api/proxy/api/workspace/files/pdf-render/download",
});
const classification = makeClassification({ type: "pdf" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("src")).toBe(blobUrl);
// No sandbox attribute — Chrome blocks PDF in sandboxed iframes
expect(iframe?.hasAttribute("sandbox")).toBe(false);
});
});
// ── Fetch error ───────────────────────────────────────────────────
it("shows error state with retry button on fetch failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
}),
);
const artifact = makeArtifact({ id: "error-content-test" });
const classification = makeClassification({ type: "html" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const errorText = await screen.findByText("Failed to load content");
expect(errorText).toBeTruthy();
const retryButtons = screen.getAllByRole("button", { name: /try again/i });
expect(retryButtons.length).toBeGreaterThan(0);
});
// ── HTML ──────────────────────────────────────────────────────────
it("renders HTML content in sandboxed iframe", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () =>
Promise.resolve("<html><body><h1>Hello World</h1></body></html>"),
}),
);
const artifact = makeArtifact({
id: "html-001",
title: "page.html",
mimeType: "text/html",
});
const classification = makeClassification({ type: "html" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTitle("page.html");
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
// ── Source view ───────────────────────────────────────────────────
it("renders source view as pre tag", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source code here"),
}),
);
const artifact = makeArtifact({ id: "source-view-test" });
const classification = makeClassification({
type: "html",
hasSourceToggle: true,
});
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={true}
classification={classification}
/>,
);
await screen.findByText("source code here");
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("source code here");
});
// ── React ─────────────────────────────────────────────────────────
it("renders react artifacts via ArtifactReactPreview", async () => {
const jsxSource = "export default function App() { return <div>Hi</div>; }";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsxSource),
}),
);
const artifact = makeArtifact({
id: "react-001",
title: "App.tsx",
mimeType: "text/tsx",
});
const classification = makeClassification({ type: "react" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const preview = await screen.findByTestId("react-preview");
expect(preview).toBeTruthy();
expect(preview.textContent).toContain(jsxSource);
expect(preview.getAttribute("data-title")).toBe("App.tsx");
});
it("routes a concrete props-based TSX artifact into ArtifactReactPreview", async () => {
const jsxSource = `
import React, { FC, useState } from "react";
interface ArtifactFile {
id: string;
name: string;
mimeType: string;
url: string;
sizeBytes: number;
}
interface Props {
files: ArtifactFile[];
onSelect: (file: ArtifactFile) => void;
}
export const previewProps: Props = {
files: [
{
id: "1",
name: "report.png",
mimeType: "image/png",
url: "/report.png",
sizeBytes: 2048,
},
],
onSelect: () => {},
};
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
const [selected, setSelected] = useState<string | null>(null);
const handleClick = (file: ArtifactFile) => {
setSelected(file.id);
onSelect(file);
};
return (
<ul>
{files.map((file) => (
<li key={file.id} onClick={() => handleClick(file)}>
<span>{selected === file.id ? "selected" : file.name}</span>
</li>
))}
</ul>
);
};
export default ArtifactList;
`;
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsxSource),
}),
);
const artifact = makeArtifact({
id: "react-props-001",
title: "ArtifactList.tsx",
mimeType: "text/tsx",
});
const classification = classifyArtifact(artifact.mimeType, artifact.title);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const preview = await screen.findByTestId("react-preview");
expect(preview.textContent).toContain("previewProps");
expect(preview.getAttribute("data-title")).toBe("ArtifactList.tsx");
expect(vi.mocked(ArtifactReactPreview).mock.calls[0]?.[0]).toEqual(
expect.objectContaining({
source: expect.stringContaining("export const previewProps"),
title: "ArtifactList.tsx",
}),
);
});
// ── Code ──────────────────────────────────────────────────────────
it("renders code artifacts via codeRenderer", async () => {
const code = 'def hello():\n print("hi")';
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(code),
}),
);
const artifact = makeArtifact({
id: "code-render-001",
title: "script.py",
mimeType: "text/x-python",
});
const classification = makeClassification({ type: "code" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("code-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain(code);
});
it.each([
{
filename: "events.jsonl",
mimeType: "application/x-ndjson",
content: '{"event":"start"}\n{"event":"finish"}',
},
{
filename: ".env.local",
mimeType: "text/plain",
content: "OPENAI_API_KEY=test\nDEBUG=true",
},
{
filename: "Dockerfile",
mimeType: "text/plain",
content: "FROM node:20\nRUN pnpm install",
},
{
filename: "schema.graphql",
mimeType: "text/plain",
content: "type Query { viewer: User }",
},
])(
"renders concrete code artifact $filename through codeRenderer",
async ({ filename, mimeType, content }) => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(content),
}),
);
const artifact = makeArtifact({
id: `code-${filename}`,
title: filename,
mimeType,
});
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTestId("code-renderer");
expect(classification.type).toBe("code");
expect(vi.mocked(codeRenderer.render)).toHaveBeenCalledWith(
content,
expect.objectContaining({
filename,
mimeType,
type: "code",
}),
);
},
);
// ── JSON ──────────────────────────────────────────────────────────
it("renders valid JSON via globalRegistry", async () => {
const jsonContent = JSON.stringify({ key: "value" }, null, 2);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsonContent),
}),
);
const artifact = makeArtifact({
id: "json-render-001",
title: "data.json",
mimeType: "application/json",
});
const classification = makeClassification({ type: "json" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("application/json");
});
it("renders invalid JSON as fallback pre tag", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"
);
const originalImpl = vi
.mocked(globalRegistry.getRenderer)
.getMockImplementation();
// For invalid JSON, JSON.parse throws, then the registry fallback
// also returns null → falls through to <pre>
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("{invalid json!!!"),
}),
);
const artifact = makeArtifact({
id: "json-invalid-001",
title: "bad.json",
mimeType: "application/json",
});
const classification = makeClassification({ type: "json" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("{invalid json!!!");
});
// Restore
if (originalImpl) {
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
}
});
// ── CSV ───────────────────────────────────────────────────────────
it("renders CSV via globalRegistry with text/csv metadata", async () => {
const csvContent = "Name,Age\nAlice,30\nBob,25";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(csvContent),
}),
);
const artifact = makeArtifact({
id: "csv-render-001",
title: "data.csv",
mimeType: "text/csv",
});
const classification = makeClassification({
type: "csv",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/csv");
});
it("renders TSV via globalRegistry with tab-separated metadata", async () => {
const tsvContent = "Name\tAge\nAlice\t30\nBob\t25";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(tsvContent),
}),
);
const artifact = makeArtifact({
id: "tsv-render-001",
title: "data.tsv",
mimeType: "text/tab-separated-values",
});
const classification = makeClassification({
type: "csv",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/tab-separated-values");
});
// ── Markdown ──────────────────────────────────────────────────────
it("renders markdown via globalRegistry", async () => {
const mdContent = "# Hello\n\nThis is **markdown**.";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(mdContent),
}),
);
const artifact = makeArtifact({
id: "md-render-001",
title: "README.md",
mimeType: "text/markdown",
});
const classification = makeClassification({
type: "markdown",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/markdown");
});
// ── Text fallback ─────────────────────────────────────────────────
it("renders text artifacts via globalRegistry fallback", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("plain text content"),
}),
);
const artifact = makeArtifact({
id: "text-render-001",
title: "notes.txt",
mimeType: "text/plain",
});
const classification = makeClassification({ type: "text" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
});
it.each([
{
filename: "calendar.ics",
mimeType: "text/calendar",
content: "BEGIN:VCALENDAR\nVERSION:2.0\nEND:VCALENDAR",
},
{
filename: "contact.vcf",
mimeType: "text/vcard",
content: "BEGIN:VCARD\nVERSION:4.0\nFN:Alice Example\nEND:VCARD",
},
])(
"renders concrete text artifact $filename through the global renderer path",
async ({ filename, mimeType, content }) => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(content),
}),
);
const artifact = makeArtifact({
id: `text-${filename}`,
title: filename,
mimeType,
});
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTestId("global-renderer");
expect(classification.type).toBe("text");
expect(vi.mocked(globalRegistry.getRenderer)).toHaveBeenCalledWith(
content,
expect.objectContaining({
filename,
mimeType,
}),
);
},
);
it("falls back to pre tag when no renderer matches", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"
);
const originalImpl = vi
.mocked(globalRegistry.getRenderer)
.getMockImplementation();
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("raw content fallback"),
}),
);
const artifact = makeArtifact({
id: "fallback-pre-001",
title: "unknown.txt",
mimeType: "text/plain",
});
const classification = makeClassification({ type: "text" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("raw content fallback");
});
// Restore
if (originalImpl) {
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
}
});
});

View File

@@ -3,6 +3,7 @@ import { renderHook, waitFor, act } from "@testing-library/react";
import {
useArtifactContent,
getCachedArtifactContent,
clearContentCache,
} from "../useArtifactContent";
import type { ArtifactRef } from "../../../../store";
import type { ArtifactClassification } from "../../helpers";
@@ -33,6 +34,7 @@ function makeClassification(
describe("useArtifactContent", () => {
beforeEach(() => {
clearContentCache();
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
@@ -44,6 +46,7 @@ describe("useArtifactContent", () => {
});
afterEach(() => {
clearContentCache();
vi.restoreAllMocks();
});
@@ -109,9 +112,12 @@ describe("useArtifactContent", () => {
useArtifactContent(artifact, classification),
);
-await waitFor(() => {
-expect(result.current.error).toBeTruthy();
-});
+await waitFor(
+() => {
+expect(result.current.error).toBeTruthy();
+},
+{ timeout: 2500 },
+);
expect(result.current.error).toContain("404");
expect(result.current.content).toBeNull();
@@ -132,6 +138,176 @@ describe("useArtifactContent", () => {
expect(getCachedArtifactContent("cache-test")).toBe("file content here");
});
it("sets error on fetch failure for HTML artifacts (stale artifact)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
}),
);
const artifact = makeArtifact({ id: "stale-html-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("404");
expect(result.current.content).toBeNull();
});
it("sets error on network failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockRejectedValue(new Error("Network error")),
);
const artifact = makeArtifact({ id: "network-error-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("Network error");
expect(result.current.content).toBeNull();
});
it("retries transient HTML fetch failures before surfacing an error", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount < 3) {
return Promise.resolve({
ok: false,
status: 503,
headers: {
get: () => "application/json",
},
json: () => Promise.resolve({ detail: "temporary upstream error" }),
});
}
return Promise.resolve({
ok: true,
text: () => Promise.resolve("<html>ok now</html>"),
});
}),
);
const artifact = makeArtifact({ id: "transient-html-retry" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.content).toBe("<html>ok now</html>");
},
{ timeout: 2500 },
);
expect(callCount).toBe(3);
expect(result.current.error).toBeNull();
});
it("surfaces backend error detail from JSON responses", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
headers: {
get: () => "application/json",
},
json: () => Promise.resolve({ detail: "File not found" }),
}),
);
const artifact = makeArtifact({ id: "json-error-detail" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("404");
expect(result.current.error).toContain("File not found");
});
it("retry after 404 on HTML artifact clears cache and re-fetches", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
});
}
return Promise.resolve({
ok: true,
text: () => Promise.resolve("<html>recovered</html>"),
});
}),
);
const artifact = makeArtifact({ id: "retry-html-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(() => {
expect(result.current.error).toBeTruthy();
});
act(() => {
result.current.retry();
});
await waitFor(
() => {
expect(result.current.content).toBe("<html>recovered</html>");
},
{ timeout: 2500 },
);
expect(result.current.error).toBeNull();
});
it("retry clears cache and re-fetches", async () => {
let callCount = 0;
vi.stubGlobal(
@@ -164,4 +340,162 @@ describe("useArtifactContent", () => {
expect(result.current.content).toBe("response 2");
});
});
// ── Non-transient errors ──────────────────────────────────────────
it("rejects immediately on 403 without retrying", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({
ok: false,
status: 403,
text: () => Promise.resolve("Forbidden"),
});
}),
);
const artifact = makeArtifact({ id: "forbidden-no-retry" });
const classification = makeClassification({ type: "text" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(callCount).toBe(1);
expect(result.current.error).toContain("403");
});
// ── Video skip-fetch ──────────────────────────────────────────────
it("skips fetch for video artifacts (like image)", async () => {
const artifact = makeArtifact({
id: "video-skip",
mimeType: "video/mp4",
});
const classification = makeClassification({ type: "video" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
expect(result.current.isLoading).toBe(false);
expect(result.current.content).toBeNull();
expect(result.current.pdfUrl).toBeNull();
expect(fetch).not.toHaveBeenCalled();
});
// ── PDF error paths ───────────────────────────────────────────────
it("sets error on PDF fetch failure (non-2xx)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 500,
text: () => Promise.resolve("Server Error"),
}),
);
const artifact = makeArtifact({ id: "pdf-error" });
const classification = makeClassification({ type: "pdf" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("500");
expect(result.current.pdfUrl).toBeNull();
});
it("sets error on PDF network failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockRejectedValue(new Error("PDF network failure")),
);
const artifact = makeArtifact({ id: "pdf-network-error" });
const classification = makeClassification({ type: "pdf" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("PDF network failure");
expect(result.current.pdfUrl).toBeNull();
});
// ── LRU cache eviction ────────────────────────────────────────────
it("evicts oldest entry when cache exceeds 12 items", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation((url: string) => {
const fileId = url.match(/files\/([^/]+)\/download/)?.[1] ?? "unknown";
return Promise.resolve({
ok: true,
text: () => Promise.resolve(`content-${fileId}`),
});
}),
);
const classification = makeClassification({ type: "text" });
// Fill the cache with 12 entries (cache max = 12)
for (let i = 0; i < 12; i++) {
const artifact = makeArtifact({
id: `lru-${i}`,
sourceUrl: `/api/proxy/api/workspace/files/lru-${i}/download`,
});
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(() => {
expect(result.current.isLoading).toBe(false);
});
}
// All 12 should be cached
expect(getCachedArtifactContent("lru-0")).toBe("content-lru-0");
expect(getCachedArtifactContent("lru-11")).toBe("content-lru-11");
// Adding a 13th should evict lru-0 (the oldest)
const artifact13 = makeArtifact({
id: "lru-12",
sourceUrl: "/api/proxy/api/workspace/files/lru-12/download",
});
const { result: result13 } = renderHook(() =>
useArtifactContent(artifact13, classification),
);
await waitFor(() => {
expect(result13.current.isLoading).toBe(false);
});
expect(getCachedArtifactContent("lru-0")).toBeUndefined();
expect(getCachedArtifactContent("lru-1")).toBe("content-lru-1");
expect(getCachedArtifactContent("lru-12")).toBe("content-lru-12");
});
});

View File

@@ -85,4 +85,35 @@ describe("buildReactArtifactSrcDoc", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("box-sizing: border-box");
});
it("supports a named previewProps export in the runtime", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("moduleExports.previewProps");
expect(doc).toContain("React.createElement(Component, previewProps || {})");
});
it("includes a helpful message for components that expect props", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("This component appears to expect props.");
expect(doc).toContain("previewProps");
});
it("checks componentExpectsProps on the raw component before wrapping", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("RawComponent.length > 0");
expect(doc).toContain("wrapWithProviders(RawComponent");
});
it("wrapWithProviders forwards props to the wrapped component", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("function WrappedArtifactPreview(props)");
expect(doc).toContain("React.createElement(Component, props)");
});
it("supports named exported components and provider wrappers in the runtime", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain('name.endsWith("Provider")');
expect(doc).toContain("/^[A-Z]/.test(name)");
expect(doc).toContain("wrapWithProviders");
});
});

View File

@@ -169,8 +169,8 @@ export function buildReactArtifactSrcDoc(
return Component;
}
-return function WrappedArtifactPreview() {
-let tree = React.createElement(Component);
+return function WrappedArtifactPreview(props) {
+let tree = React.createElement(Component, props);
for (let i = providers.length - 1; i >= 0; i -= 1) {
tree = React.createElement(providers[i], null, tree);
@@ -180,6 +180,17 @@ export function buildReactArtifactSrcDoc(
};
}
function getPreviewProps(moduleExports) {
if (
moduleExports.previewProps &&
typeof moduleExports.previewProps === "object"
) {
return moduleExports.previewProps;
}
return null;
}
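// Only object-valued previewProps exports are forwarded; anything else yields
// null, and the runtime then renders the component with an empty props object.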
function require(name) {
if (name === "react") {
return React;
@@ -235,6 +246,11 @@ export function buildReactArtifactSrcDoc(
render() {
if (this.state.error) {
const propsHelp =
this.props.componentExpectsProps && !this.props.hasPreviewProps
? "\\n\\nThis component appears to expect props. Export a named previewProps object with sample values to render it in artifact preview."
: "";
return React.createElement(
"div",
{
@@ -249,7 +265,9 @@ export function buildReactArtifactSrcDoc(
whiteSpace: "pre-wrap",
},
},
-this.state.error.stack || this.state.error.message || String(this.state.error),
+(this.state.error.stack ||
+this.state.error.message ||
+String(this.state.error)) + propsHelp,
);
}
@@ -296,16 +314,19 @@ export function buildReactArtifactSrcDoc(
moduleExports.App = executionResult.app;
}
-const Component = wrapWithProviders(
-getRenderableCandidate(moduleExports),
-moduleExports,
-);
+const RawComponent = getRenderableCandidate(moduleExports);
+const componentExpectsProps = RawComponent.length > 0;
+const Component = wrapWithProviders(RawComponent, moduleExports);
+const previewProps = getPreviewProps(moduleExports);
ReactDOM.createRoot(rootElement).render(
React.createElement(
PreviewErrorBoundary,
-null,
-React.createElement(Component),
+{
+componentExpectsProps: componentExpectsProps,
+hasPreviewProps: previewProps != null,
+},
+React.createElement(Component, previewProps || {}),
),
);
} catch (error) {

View File

@@ -48,4 +48,104 @@ describe("transpileReactArtifactSource", () => {
expect(out).not.toContain(": string");
expect(out).toContain("function greet(name)");
});
it("transpiles a concrete props-based artifact with previewProps", async () => {
const src = `
import React, { FC, useState } from "react";
interface ArtifactFile {
id: string;
name: string;
mimeType: string;
url: string;
sizeBytes: number;
}
interface Props {
files: ArtifactFile[];
onSelect: (file: ArtifactFile) => void;
}
export const previewProps: Props = {
files: [
{
id: "1",
name: "report.png",
mimeType: "image/png",
url: "/report.png",
sizeBytes: 2048,
},
],
onSelect: () => {},
};
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
const [selected, setSelected] = useState<string | null>(null);
const handleClick = (file: ArtifactFile) => {
setSelected(file.id);
onSelect(file);
};
return (
<ul>
{files.map((file) => (
<li key={file.id} onClick={() => handleClick(file)}>
<span>{selected === file.id ? "selected" : file.name}</span>
</li>
))}
</ul>
);
};
export default ArtifactList;
`;
const out = await transpileReactArtifactSource(src, "ArtifactList.tsx");
expect(out).toContain("exports.previewProps");
expect(out).toContain("exports.default = ArtifactList");
expect(out).toContain("useState");
expect(out).not.toContain("interface Props");
expect(out).not.toContain("interface ArtifactFile");
});
it("transpiles a named export artifact without a default export", async () => {
const src = `
export function ResultsGrid() {
return (
<section>
<h1>Results</h1>
<p>Named export preview</p>
</section>
);
}
`;
const out = await transpileReactArtifactSource(src, "ResultsGrid.tsx");
expect(out).toContain("exports.ResultsGrid = ResultsGrid");
expect(out).toMatch(/\.createElement\(/);
expect(out).not.toContain("<section>");
});
it("transpiles a provider-wrapped artifact with separate provider and component exports", async () => {
const src = `
import React from "react";
export function DemoProvider({ children }: { children: React.ReactNode }) {
return <div data-theme="demo">{children}</div>;
}
export function DashboardCard() {
return <main>Provider-backed preview</main>;
}
`;
const out = await transpileReactArtifactSource(src, "DashboardCard.tsx");
expect(out).toContain("exports.DemoProvider = DemoProvider");
expect(out).toContain("exports.DashboardCard = DashboardCard");
expect(out).not.toContain("React.ReactNode");
});
});

View File

@@ -7,12 +7,116 @@ import type { ArtifactClassification } from "../helpers";
// Cap on cached text artifacts. Long sessions with many large artifacts
// would otherwise hold every opened one in memory.
const CONTENT_CACHE_MAX = 12;
const CONTENT_FETCH_MAX_RETRIES = 2;
const CONTENT_FETCH_RETRY_DELAY_MS = 500;
// Module-level LRU keyed by artifact id so a sibling action (e.g. Copy
// in ArtifactPanelHeader) can read what the panel already fetched without
// re-hitting the network.
const contentCache = new Map<string, string>();
class ArtifactFetchError extends Error {}
function isTransientArtifactFetchStatus(status: number): boolean {
return status === 408 || status === 429 || status >= 500;
}
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function getArtifactErrorMessage(body: unknown): string | null {
if (typeof body === "string") {
const trimmed = body.replace(/\s+/g, " ").trim();
return trimmed || null;
}
if (!body || typeof body !== "object") return null;
if (
"detail" in body &&
typeof body.detail === "string" &&
body.detail.trim().length > 0
) {
return body.detail.trim();
}
if (
"error" in body &&
typeof body.error === "string" &&
body.error.trim().length > 0
) {
return body.error.trim();
}
if (
"detail" in body &&
body.detail &&
typeof body.detail === "object" &&
"message" in body.detail &&
typeof body.detail.message === "string" &&
body.detail.message.trim().length > 0
) {
return body.detail.message.trim();
}
return null;
}
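// e.g. { detail: "File not found" }      -> "File not found"
//      { error: "boom" }                 -> "boom"
//      { detail: { message: "nested" } } -> "nested"
//      "  upstream\n  error  "           -> "upstream error"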
async function parseArtifactFetchError(response: Response): Promise<string> {
const prefix = `Failed to fetch: ${response.status}`;
const contentType =
response.headers?.get?.("content-type")?.toLowerCase() ?? "";
try {
if (
contentType.includes("application/json") &&
typeof response.json === "function"
) {
const body = await response.json();
const detail = getArtifactErrorMessage(body);
return detail ? `${prefix} ${detail}` : prefix;
}
if (typeof response.text === "function") {
const text = await response.text();
const detail = getArtifactErrorMessage(text);
return detail ? `${prefix} ${detail}` : prefix;
}
} catch {
return prefix;
}
return prefix;
}
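// Whenever the body is unreadable or carries no usable detail, the bare
// "Failed to fetch: <status>" prefix is returned as-is.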
async function fetchArtifactResponse(url: string): Promise<Response> {
for (let attempt = 0; attempt <= CONTENT_FETCH_MAX_RETRIES; attempt++) {
try {
const response = await fetch(url);
if (response.ok) return response;
if (
!isTransientArtifactFetchStatus(response.status) ||
attempt === CONTENT_FETCH_MAX_RETRIES
) {
throw new ArtifactFetchError(await parseArtifactFetchError(response));
}
} catch (error) {
if (error instanceof ArtifactFetchError) throw error;
if (attempt === CONTENT_FETCH_MAX_RETRIES) {
throw error instanceof Error
? error
: new Error("Failed to fetch artifact");
}
}
await sleep(CONTENT_FETCH_RETRY_DELAY_MS);
}
throw new Error("Failed to fetch artifact");
}
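// Worst case: 1 + CONTENT_FETCH_MAX_RETRIES requests with a flat 500 ms pause
// between attempts (no exponential backoff), i.e. at most three requests.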
export function getCachedArtifactContent(id: string): string | undefined {
return contentCache.get(id);
}
@@ -64,7 +168,7 @@ export function useArtifactContent(
}, [artifact.id, isLoading]);
useEffect(() => {
-if (classification.type === "image") {
+if (classification.type === "image" || classification.type === "video") {
setContent(null);
setPdfUrl(null);
setError(null);
@@ -80,11 +184,8 @@ export function useArtifactContent(
let objectUrl: string | null = null;
setContent(null);
setPdfUrl(null);
-fetch(artifact.sourceUrl)
-.then((res) => {
-if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
-return res.blob();
-})
+fetchArtifactResponse(artifact.sourceUrl)
+.then((res) => res.blob())
.then((blob) => {
objectUrl = URL.createObjectURL(blob);
if (cancelled) {
@@ -121,11 +222,8 @@ export function useArtifactContent(
cancelled = true;
};
}
fetch(artifact.sourceUrl)
.then((res) => {
if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
return res.text();
})
fetchArtifactResponse(artifact.sourceUrl)
.then((res) => res.text())
.then((text) => {
if (!cancelled) {
if (cache.size >= CONTENT_CACHE_MAX) {

View File

@@ -1,5 +1,31 @@
import type { ArtifactRef } from "../../store";
const MAX_RETRIES = 2;
const RETRY_DELAY_MS = 500;
function isTransientError(status: number): boolean {
return status >= 500 || status === 408 || status === 429;
}
class DownloadError extends Error {}
async function fetchWithRetry(url: string, retries: number): Promise<Response> {
for (let attempt = 0; attempt <= retries; attempt++) {
try {
const res = await fetch(url);
if (res.ok) return res;
if (!isTransientError(res.status) || attempt === retries) {
throw new DownloadError(`Download failed: ${res.status}`);
}
} catch (error) {
if (error instanceof DownloadError) throw error;
if (attempt === retries) throw error;
}
await new Promise((r) => setTimeout(r, RETRY_DELAY_MS));
}
throw new Error("Unreachable");
}
/**
* Trigger a file download from an artifact URL.
*
@@ -7,26 +33,28 @@ import type { ArtifactRef } from "../../store";
* ignores the `download` attribute on cross-origin responses (GCS signed
* URLs), and some browsers require the anchor to be attached to the DOM
* before `.click()` fires the download.
*
* Retries up to {@link MAX_RETRIES} times on transient HTTP errors (5xx,
* 408, 429) to handle intermittent proxy/GCS failures.
*/
export function downloadArtifact(artifact: ArtifactRef): Promise<void> {
// Replace path separators, Windows-reserved chars, control chars, and
// parent-dir sequences so the browser-assigned filename is safe to write
// anywhere on the user's filesystem.
const safeName =
artifact.title
.replace(/\.\./g, "_")
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
.replace(/^\.+/, "") || "download";
return fetch(artifact.sourceUrl)
.then((res) => {
if (!res.ok) throw new Error(`Download failed: ${res.status}`);
return res.blob();
})
const collapsedDots = artifact.title.replace(/\.\./g, "");
const hasVisibleName = collapsedDots.replace(/^\.+/, "").length > 0;
const safeName = artifact.title
.replace(/\.\./g, "_")
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
.replace(/^\.+/, "");
return fetchWithRetry(artifact.sourceUrl, MAX_RETRIES)
.then((res) => res.blob())
.then((blob) => {
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = safeName;
a.download = safeName && hasVisibleName ? safeName : "download";
document.body.appendChild(a);
a.click();
a.remove();
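The sanitizer reads cleanly as a standalone function; a hypothetical extraction with sample inputs (it mirrors the inline logic above and is not exported by the diff):

function sanitizeDownloadName(title: string): string {
  const collapsedDots = title.replace(/\.\./g, "");
  const hasVisibleName = collapsedDots.replace(/^\.+/, "").length > 0;
  const safe = title
    .replace(/\.\./g, "_")
    .replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
    .replace(/^\.+/, "");
  return safe && hasVisibleName ? safe : "download";
}

sanitizeDownloadName("../../etc/passwd"); // "____etc_passwd"
sanitizeDownloadName("report.pdf");       // "report.pdf"
sanitizeDownloadName("...");              // "download" (nothing visible survives)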

View File

@@ -56,7 +56,7 @@ describe("classifyArtifact", () => {
expect(classifyArtifact("application/octet-stream", "x").openable).toBe(
false,
);
expect(classifyArtifact("video/mp4", "clip.mp4").openable).toBe(false);
expect(classifyArtifact("audio/mpeg", "track.mp3").openable).toBe(false);
});
it("defaults unknown extension+MIME to download-only (not text)", () => {
@@ -76,4 +76,398 @@ describe("classifyArtifact", () => {
const c = classifyArtifact("text/plain", "data.csv");
expect(c.type).toBe("csv");
});
it("classifies video/mp4 as video (previewable)", () => {
const c = classifyArtifact("video/mp4", "clip.mp4");
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
});
it("classifies video/webm as video (previewable)", () => {
const c = classifyArtifact("video/webm", "clip.webm");
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
});
// ── Extension coverage ────────────────────────────────────────────
it("routes .htm as html (not just .html)", () => {
const c = classifyArtifact(null, "page.htm");
expect(c.type).toBe("html");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .json as json with source toggle", () => {
const c = classifyArtifact(null, "config.json");
expect(c.type).toBe("json");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .txt as text", () => {
expect(classifyArtifact(null, "notes.txt").type).toBe("text");
});
it("routes .log as text", () => {
expect(classifyArtifact(null, "server.log").type).toBe("text");
});
it("routes .mdx as markdown", () => {
expect(classifyArtifact(null, "docs.mdx").type).toBe("markdown");
});
it("routes browser-safe video extensions to video", () => {
for (const ext of [".mp4", ".webm", ".m4v"]) {
const c = classifyArtifact(null, `clip${ext}`);
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
}
});
it("keeps legacy or unsupported video extensions download-only", () => {
for (const ext of [".ogg", ".mov", ".avi", ".mkv", ".flv", ".mpeg"]) {
const c = classifyArtifact(null, `clip${ext}`);
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
it("routes all code extensions to code", () => {
const codeExts = [
"main.js",
"app.ts",
"theme.scss",
"legacy.less",
"schema.graphql",
"query.gql",
"api.proto",
"main.dart",
"lib.rb",
"server.rs",
"App.java",
"main.c",
"util.cpp",
"header.h",
"Program.cs",
"index.php",
"main.swift",
"App.kt",
"run.sh",
"start.bash",
"prompt.zsh",
"config.toml",
"settings.ini",
"app.cfg",
"query.sql",
"analysis.r",
"game.lua",
"script.pl",
"Calc.scala",
];
for (const file of codeExts) {
expect(classifyArtifact(null, file).type).toBe("code");
}
});
it("routes config filenames and extensions to code", () => {
const configFiles = [
".env",
".env.local",
"app.properties",
"service.conf",
".gitignore",
"Dockerfile",
"Makefile",
];
for (const file of configFiles) {
expect(classifyArtifact(null, file).type).toBe("code");
}
});
it("routes .jsonl as code for now", () => {
const c = classifyArtifact(null, "events.jsonl");
expect(c.type).toBe("code");
});
it("routes .tsv as csv/spreadsheet", () => {
const c = classifyArtifact(null, "table.tsv");
expect(c.type).toBe("csv");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .ics and .vcf as text", () => {
expect(classifyArtifact(null, "calendar.ics").type).toBe("text");
expect(classifyArtifact(null, "contact.vcf").type).toBe("text");
});
it("routes all image extensions to image", () => {
for (const ext of [".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico"]) {
expect(classifyArtifact(null, `file${ext}`).type).toBe("image");
}
});
// ── MIME fallback coverage ────────────────────────────────────────
it("routes application/json MIME to json", () => {
const c = classifyArtifact("application/json", "noext");
expect(c.type).toBe("json");
});
it("routes text/x-* MIME prefix to code", () => {
expect(classifyArtifact("text/x-python", "noext").type).toBe("code");
expect(classifyArtifact("text/x-c", "noext").type).toBe("code");
expect(classifyArtifact("text/x-java-source", "noext").type).toBe("code");
});
it("routes react MIME types to react", () => {
expect(classifyArtifact("text/jsx", "noext").type).toBe("react");
expect(classifyArtifact("text/tsx", "noext").type).toBe("react");
expect(classifyArtifact("application/jsx", "noext").type).toBe("react");
expect(classifyArtifact("application/x-typescript-jsx", "noext").type).toBe(
"react",
);
});
it("routes JavaScript/TypeScript MIME to code", () => {
expect(classifyArtifact("application/javascript", "noext").type).toBe(
"code",
);
expect(classifyArtifact("text/javascript", "noext").type).toBe("code");
expect(classifyArtifact("application/typescript", "noext").type).toBe(
"code",
);
expect(classifyArtifact("text/typescript", "noext").type).toBe("code");
});
it("routes XML MIME to code", () => {
expect(classifyArtifact("application/xml", "noext").type).toBe("code");
expect(classifyArtifact("text/xml", "noext").type).toBe("code");
});
it("routes text/x-markdown MIME to markdown", () => {
expect(classifyArtifact("text/x-markdown", "noext").type).toBe("markdown");
});
it("routes text/csv MIME to csv", () => {
expect(classifyArtifact("text/csv", "noext").type).toBe("csv");
});
it("routes TSV MIME to csv", () => {
expect(classifyArtifact("text/tab-separated-values", "noext").type).toBe(
"csv",
);
});
it("routes unknown text/* MIME to text (not download-only)", () => {
expect(classifyArtifact("text/rtf", "noext").type).toBe("text");
});
it("routes browser-safe image MIME types to image", () => {
expect(classifyArtifact("image/avif", "noext").type).toBe("image");
});
it("keeps unsupported image MIME types download-only", () => {
for (const mime of [
"image/tiff",
"image/x-portable-pixmap",
"image/x-portable-graymap",
]) {
const c = classifyArtifact(mime, "noext");
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
it("routes browser-safe video MIME types to video", () => {
expect(classifyArtifact("video/mp4", "noext").type).toBe("video");
expect(classifyArtifact("video/webm", "noext").type).toBe("video");
});
it("keeps legacy or unsupported video MIME types download-only", () => {
for (const mime of [
"video/x-msvideo",
"video/x-flv",
"video/mpeg",
"video/quicktime",
"video/x-matroska",
"video/ogg",
]) {
const c = classifyArtifact(mime, "noext");
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
// ── BINARY_MIMES coverage ────────────────────────────────────────
it("treats all BINARY_MIMES entries as download-only", () => {
const binaryMimes = [
"application/zip",
"application/x-zip-compressed",
"application/gzip",
"application/x-tar",
"application/x-rar-compressed",
"application/x-7z-compressed",
"application/octet-stream",
"application/x-executable",
"application/x-msdos-program",
"application/vnd.microsoft.portable-executable",
];
for (const mime of binaryMimes) {
const c = classifyArtifact(mime, "noext");
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
}
});
it("treats audio/* MIME as download-only", () => {
expect(classifyArtifact("audio/mpeg", "noext").openable).toBe(false);
expect(classifyArtifact("audio/wav", "noext").openable).toBe(false);
expect(classifyArtifact("audio/ogg", "noext").openable).toBe(false);
});
// ── Size gate edge cases ──────────────────────────────────────────
it("does NOT gate files at exactly 10MB (boundary is >10MB)", () => {
const tenMB = 10 * 1024 * 1024;
const c = classifyArtifact("text/plain", "exact.txt", tenMB);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
it("gates files at 10MB + 1 byte", () => {
const overTenMB = 10 * 1024 * 1024 + 1;
const c = classifyArtifact("text/plain", "big.txt", overTenMB);
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
});
it("does not gate when sizeBytes is 0", () => {
const c = classifyArtifact("text/plain", "empty.txt", 0);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
it("does not gate when sizeBytes is undefined", () => {
const c = classifyArtifact("text/plain", "file.txt", undefined);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
// ── Extension over MIME priority ──────────────────────────────────
it("extension wins over MIME for JSON (MIME says text, ext says json)", () => {
const c = classifyArtifact("text/plain", "data.json");
expect(c.type).toBe("json");
});
it("extension wins over MIME for markdown", () => {
const c = classifyArtifact("text/plain", "README.md");
expect(c.type).toBe("markdown");
});
// ── Null/missing inputs ───────────────────────────────────────────
it("handles null MIME with no filename as download-only", () => {
const c = classifyArtifact(null, undefined);
expect(c.type).toBe("download-only");
});
it("handles null MIME with empty filename as download-only", () => {
const c = classifyArtifact(null, "");
expect(c.type).toBe("download-only");
});
it("handles known config files with no extension", () => {
const c = classifyArtifact(null, "Makefile");
expect(c.type).toBe("code");
});
// ── Exotic/compound extensions must NOT open the side panel ───────
// These are real file types agents might produce. Every single one
// must be download-only so we never try to render binary garbage.
it("does not open .tar.gz (compound extension takes last segment)", () => {
// getExtension("archive.tar.gz") → ".gz" which is not in EXT_KIND
const c = classifyArtifact(null, "archive.tar.gz");
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
});
it("does not open .tar.bz2", () => {
const c = classifyArtifact(null, "archive.tar.bz2");
expect(c.openable).toBe(false);
});
it("does not open .tar.xz", () => {
const c = classifyArtifact(null, "archive.tar.xz");
expect(c.openable).toBe(false);
});
it("does not open common binary formats", () => {
const binaries = [
"setup.exe",
"library.dll",
"image.iso",
"installer.dmg",
"package.deb",
"package.rpm",
"module.wasm",
"Main.class",
"module.pyc",
"app.apk",
"game.pak",
"model.onnx",
"weights.pt",
"data.parquet",
"archive.rar",
"archive.7z",
"disk.vhd",
"disk.vmdk",
"firmware.bin",
"core.dump",
"database.sqlite",
"database.db",
"index.idx",
];
for (const file of binaries) {
const c = classifyArtifact(null, file);
expect(c.openable).toBe(false);
}
});
it("does not open binary MIME types even with a misleading extension", () => {
// Extension is unknown, MIME is binary
const c = classifyArtifact("application/x-executable", "run.elf");
expect(c.openable).toBe(false);
});
it("does not open files with random/made-up extensions", () => {
const weirdExts = [
"output.xyz",
"data.foo",
"file.asdf",
"thing.blargh",
"result.out",
"x.1234",
];
for (const file of weirdExts) {
const c = classifyArtifact(null, file);
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
}
});
it("does not open font files", () => {
for (const file of ["sans.ttf", "serif.otf", "icon.woff", "icon.woff2"]) {
expect(classifyArtifact(null, file).openable).toBe(false);
}
});
it("does not open certificate/key files", () => {
// .pem and .key have no extension mapping and null MIME → download-only
for (const file of ["cert.pem", "server.key", "ca.crt", "id.p12"]) {
expect(classifyArtifact(null, file).openable).toBe(false);
}
});
});

View File

@@ -5,6 +5,7 @@ import {
FileText,
Image,
Table,
VideoCamera,
} from "@phosphor-icons/react";
import type { Icon } from "@phosphor-icons/react";
@@ -17,6 +18,7 @@ export interface ArtifactClassification {
| "csv"
| "json"
| "image"
| "video"
| "pdf"
| "text"
| "download-only";
@@ -38,6 +40,13 @@ const KIND: Record<string, ArtifactClassification> = {
openable: true,
hasSourceToggle: false,
},
video: {
type: "video",
icon: VideoCamera,
label: "Video",
openable: true,
hasSourceToggle: false,
},
pdf: {
type: "pdf",
icon: FileText,
@@ -113,8 +122,13 @@ const EXT_KIND: Record<string, string> = {
".svg": "image",
".bmp": "image",
".ico": "image",
".avif": "image",
".mp4": "video",
".webm": "video",
".m4v": "video",
".pdf": "pdf",
".csv": "csv",
".tsv": "csv",
".html": "html",
".htm": "html",
".jsx": "react",
@@ -122,11 +136,17 @@ const EXT_KIND: Record<string, string> = {
".md": "markdown",
".mdx": "markdown",
".json": "json",
".jsonl": "code",
".txt": "text",
".log": "text",
".ics": "text",
".vcf": "text",
".env": "code",
".gitignore": "code",
// code extensions
".js": "code",
".ts": "code",
".dart": "code",
".py": "code",
".rb": "code",
".go": "code",
@@ -142,11 +162,19 @@ const EXT_KIND: Record<string, string> = {
".sh": "code",
".bash": "code",
".zsh": "code",
".scss": "code",
".sass": "code",
".less": "code",
".graphql": "code",
".gql": "code",
".proto": "code",
".yml": "code",
".yaml": "code",
".toml": "code",
".ini": "code",
".cfg": "code",
".conf": "code",
".properties": "code",
".sql": "code",
".r": "code",
".lua": "code",
@@ -154,10 +182,16 @@ const EXT_KIND: Record<string, string> = {
".scala": "code",
};
const EXACT_FILENAME_KIND: Record<string, string> = {
dockerfile: "code",
makefile: "code",
};
// Exact-match MIME → kind (fallback when extension doesn't match).
const MIME_KIND: Record<string, string> = {
"application/pdf": "pdf",
"text/csv": "csv",
"text/tab-separated-values": "csv",
"text/html": "html",
"text/jsx": "react",
"text/tsx": "react",
@@ -166,6 +200,9 @@ const MIME_KIND: Record<string, string> = {
"text/markdown": "markdown",
"text/x-markdown": "markdown",
"application/json": "json",
"application/x-ndjson": "code",
"application/ndjson": "code",
"application/jsonl": "code",
"application/javascript": "code",
"text/javascript": "code",
"application/typescript": "code",
@@ -182,11 +219,37 @@ const BINARY_MIMES = new Set([
"application/x-rar-compressed",
"application/x-7z-compressed",
"application/octet-stream",
"application/wasm",
"application/x-executable",
"application/x-msdos-program",
"application/vnd.microsoft.portable-executable",
]);
const PREVIEWABLE_IMAGE_MIMES = new Set([
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/svg+xml",
"image/bmp",
"image/x-icon",
"image/vnd.microsoft.icon",
"image/avif",
]);
const PREVIEWABLE_VIDEO_MIMES = new Set([
"video/mp4",
"video/webm",
"video/x-m4v",
]);
function getBasename(filename?: string): string {
if (!filename) return "";
const normalized = filename.replace(/\\/g, "/");
const parts = normalized.split("/");
return parts[parts.length - 1]?.toLowerCase() ?? "";
}
function getExtension(filename?: string): string {
if (!filename) return "";
const lastDot = filename.lastIndexOf(".");
@@ -202,24 +265,36 @@ export function classifyArtifact(
// Size gate: >10MB is download-only regardless of type.
if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];
const basename = getBasename(filename);
const exactKind = EXACT_FILENAME_KIND[basename];
if (exactKind) return KIND[exactKind];
if (basename === ".env" || basename.startsWith(".env.")) {
return KIND.code;
}
// Extension first (more reliable than MIME for AI-generated files).
const ext = getExtension(filename);
const ext = getExtension(basename);
const extKind = EXT_KIND[ext];
if (extKind) return KIND[extKind];
// MIME fallbacks.
const mime = (mimeType ?? "").toLowerCase();
if (mime.startsWith("image/")) return KIND.image;
if (PREVIEWABLE_IMAGE_MIMES.has(mime)) return KIND.image;
if (PREVIEWABLE_VIDEO_MIMES.has(mime)) return KIND.video;
const mimeKind = MIME_KIND[mime];
if (mimeKind) return KIND[mimeKind];
if (mime.startsWith("text/x-")) return KIND.code;
if (
BINARY_MIMES.has(mime) ||
mime.startsWith("audio/") ||
mime.startsWith("video/")
mime.startsWith("image/") ||
mime.startsWith("video/") ||
mime.startsWith("font/")
) {
return KIND["download-only"];
}
if (BINARY_MIMES.has(mime) || mime.startsWith("audio/")) {
return KIND["download-only"];
}
if (mime.startsWith("text/")) return KIND.text;
// Unknown extension + unknown MIME: don't open — we can't safely assume
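Putting the lookup order together (size gate, exact filename, `.env` prefix, extension map, previewable MIME whitelists, MIME map, `text/x-` prefix, binary/audio gates, generic `text/*`), a few illustrative calls against the merged file, matching the test expectations above:

classifyArtifact("text/plain", "data.json").type;  // "json" (extension beats MIME)
classifyArtifact("video/mp4", "noext").type;       // "video" (whitelisted MIME)
classifyArtifact("video/quicktime", "noext").type; // "download-only" (not whitelisted)
classifyArtifact(null, ".env.production").type;    // "code" (basename prefix rule)
classifyArtifact("text/plain", "big.txt", 10 * 1024 * 1024 + 1).type; // "download-only" (size gate)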

View File

@@ -83,6 +83,7 @@ export function useArtifactPanel() {
const canCopy =
classification != null &&
classification.type !== "image" &&
classification.type !== "video" &&
classification.type !== "download-only" &&
classification.type !== "pdf";

View File

@@ -64,10 +64,7 @@ export const ChatContainer = ({
// open state drive layout width; an artifact generated in a stale session
// state would otherwise shrink the chat column with no panel rendered.
const isArtifactOpen = isArtifactsEnabled && isArtifactPanelOpen;
useAutoOpenArtifacts({
messages: isArtifactsEnabled ? messages : [],
sessionId,
});
useAutoOpenArtifacts({ sessionId });
const isBusy =
status === "streaming" ||
status === "submitted" ||

View File

@@ -0,0 +1,77 @@
import { describe, expect, it, beforeEach, afterEach } from "vitest";
import { renderHook } from "@testing-library/react";
import { useAutoOpenArtifacts } from "../useAutoOpenArtifacts";
import { useCopilotUIStore } from "../../../store";
// Capture the real store actions before any test can replace them.
const realOpenArtifact = useCopilotUIStore.getState().openArtifact;
const realResetArtifactPanel = useCopilotUIStore.getState().resetArtifactPanel;
function resetStore() {
useCopilotUIStore.setState({
openArtifact: realOpenArtifact,
resetArtifactPanel: realResetArtifactPanel,
artifactPanel: {
isOpen: false,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: null,
history: [],
},
});
}
describe("useAutoOpenArtifacts", () => {
beforeEach(resetStore);
afterEach(resetStore);
it("does not auto-open artifacts on initial message load", () => {
renderHook(() => useAutoOpenArtifacts({ sessionId: "session-1" }));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("does not auto-open when rerendering within the same session", () => {
const { rerender } = renderHook(
({ sessionId }: { sessionId: string }) =>
useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "session-1" } },
);
rerender({ sessionId: "session-1" });
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("panel should fully reset when session changes", () => {
const artifact = {
id: "file1",
title: "image.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/file1/download",
origin: "agent" as const,
};
useCopilotUIStore.getState().openArtifact(artifact);
useCopilotUIStore.getState().openArtifact({
...artifact,
id: "file2",
title: "second.png",
sourceUrl: "/api/proxy/api/workspace/files/file2/download",
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
const { rerender } = renderHook(
({ sessionId }: { sessionId: string }) =>
useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "session-1" } },
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
rerender({ sessionId: "session-2" });
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
});

View File

@@ -3,17 +3,19 @@ import { beforeEach, describe, expect, it } from "vitest";
import { useCopilotUIStore } from "../../store";
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";
function assistantMessageWithText(id: string, text: string) {
return {
id,
role: "assistant" as const,
parts: [{ type: "text" as const, text }],
};
}
const A_ID = "11111111-0000-0000-0000-000000000000";
const B_ID = "22222222-0000-0000-0000-000000000000";
function makeArtifact(id: string, title = `${id}.txt`) {
return {
id,
title,
mimeType: "text/plain",
sourceUrl: `/api/proxy/api/workspace/files/${id}/download`,
origin: "agent" as const,
};
}
function resetStore() {
useCopilotUIStore.setState({
artifactPanel: {
@@ -30,111 +32,60 @@ function resetStore() {
describe("useAutoOpenArtifacts", () => {
beforeEach(resetStore);
it("does NOT auto-open on the initial hydration of message list (baseline pass)", () => {
const messages = [
assistantMessageWithText("m1", `[a](workspace://${A_ID})`),
];
renderHook(() =>
useAutoOpenArtifacts({ messages: messages as any, sessionId: "s1" }),
);
// Initial run just records the baseline fingerprint; nothing opens.
it("does not auto-open on initial render", () => {
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("auto-opens when an existing assistant message adds a new artifact", () => {
// 1st render: baseline with no artifact.
const initial = [assistantMessageWithText("m1", "thinking...")];
it("does not auto-open when rerendering within the same session", () => {
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{ initialProps: { messages: initial, sessionId: "s1" } },
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
// 2nd render: same message id now contains an artifact link.
act(() => {
rerender({
messages: [
assistantMessageWithText("m1", `here: [A](workspace://${A_ID})`),
],
sessionId: "s1",
});
rerender({ sessionId: "s1" });
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("resets the panel state when sessionId changes", () => {
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
const { rerender } = renderHook(
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
act(() => {
rerender({ sessionId: "s2" });
});
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe(A_ID);
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("does not re-open when the fingerprint hasn't changed", () => {
const msg = assistantMessageWithText("m1", `[A](workspace://${A_ID})`);
it("does not carry a stale back stack into the next session", () => {
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{ initialProps: { messages: [msg], sessionId: "s1" } },
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
// Baseline captured; no open.
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
// Rerender identical content: no change in fingerprint → no open.
act(() => {
rerender({ messages: [msg], sessionId: "s1" });
rerender({ sessionId: "s2" });
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("auto-opens when a brand-new assistant message arrives after the baseline is established", () => {
// First render: one message without artifacts → establishes baseline.
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{
initialProps: {
messages: [assistantMessageWithText("m1", "plain")] as any,
sessionId: "s1",
},
},
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
useCopilotUIStore.getState().openArtifact(makeArtifact("c", "c.txt"));
// Second render: a *new* assistant message with an artifact. Baseline
// is already set, so this should auto-open.
act(() => {
rerender({
messages: [
assistantMessageWithText("m1", "plain"),
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
] as any,
sessionId: "s1",
});
});
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe(B_ID);
});
it("resets hydration baseline when sessionId changes", () => {
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{
initialProps: {
messages: [
assistantMessageWithText("m1", `[A](workspace://${A_ID})`),
] as any,
sessionId: "s1",
},
},
);
// Switch to a new session — the first pass on the new session should
// NOT auto-open (it's a fresh hydration).
act(() => {
rerender({
messages: [
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
] as any,
sessionId: "s2",
});
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
expect(s.activeArtifact?.id).toBe("c");
expect(s.history).toEqual([]);
});
});

View File

@@ -1,91 +1,29 @@
"use client";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef } from "react";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
import { getMessageArtifacts } from "../ChatMessagesContainer/helpers";
function fingerprintArtifacts(artifacts: ArtifactRef[]): string {
return artifacts
.map((a) => `${a.id}:${a.title}:${a.mimeType ?? ""}:${a.sourceUrl}`)
.join("|");
}
interface UseAutoOpenArtifactsOptions {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
sessionId: string | null;
}
export function useAutoOpenArtifacts({
messages,
sessionId,
}: UseAutoOpenArtifactsOptions) {
const openArtifact = useCopilotUIStore((state) => state.openArtifact);
const messageFingerprintsRef = useRef<Map<string, string>>(new Map());
const hasInitializedRef = useRef(false);
const resetArtifactPanel = useCopilotUIStore(
(state) => state.resetArtifactPanel,
);
const prevSessionIdRef = useRef(sessionId);
useEffect(() => {
messageFingerprintsRef.current = new Map();
hasInitializedRef.current = false;
}, [sessionId]);
const isSessionChange = prevSessionIdRef.current !== sessionId;
prevSessionIdRef.current = sessionId;
useEffect(() => {
if (messages.length === 0) {
messageFingerprintsRef.current = new Map();
return;
// Artifact previews should open only from an explicit user click.
// When the session changes, fully clear the panel state so stale
// active artifacts and back-stack entries never bleed into the next chat.
if (isSessionChange) {
resetArtifactPanel();
}
// Only scan messages whose fingerprint might have changed since the
// last pass: that's the last assistant message (currently streaming)
// plus any assistant message whose id isn't in the baseline yet.
// This keeps the cost O(new+tail), not O(all messages), on every chunk.
const previous = messageFingerprintsRef.current;
const nextFingerprints = new Map<string, string>(previous);
let nextArtifact: ArtifactRef | null = null;
const lastAssistantIdx = (() => {
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === "assistant") return i;
}
return -1;
})();
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
if (message.role !== "assistant") continue;
const isTailAssistant = i === lastAssistantIdx;
const isNewMessage = !previous.has(message.id);
if (!isTailAssistant && !isNewMessage) continue;
const artifacts = getMessageArtifacts(message);
const fingerprint = fingerprintArtifacts(artifacts);
nextFingerprints.set(message.id, fingerprint);
if (!hasInitializedRef.current || fingerprint.length === 0) {
continue;
}
const previousFingerprint = previous.get(message.id) ?? "";
if (previousFingerprint === fingerprint) continue;
nextArtifact = artifacts[artifacts.length - 1] ?? nextArtifact;
}
// Drop entries for messages that no longer exist (e.g. history truncated).
const liveIds = new Set(messages.map((m) => m.id));
for (const id of nextFingerprints.keys()) {
if (!liveIds.has(id)) nextFingerprints.delete(id);
}
messageFingerprintsRef.current = nextFingerprints;
if (!hasInitializedRef.current) {
hasInitializedRef.current = true;
return;
}
if (nextArtifact) {
openArtifact(nextArtifact);
}
}, [messages, openArtifact]);
}, [sessionId, resetArtifactPanel]);
}

View File

@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { DryRunToggleButton } from "./components/DryRunToggleButton";
import { FileChips } from "./components/FileChips";
import { ModelToggleButton } from "./components/ModelToggleButton";
import { ModeToggleButton } from "./components/ModeToggleButton";
import { RecordingButton } from "./components/RecordingButton";
import { RecordingIndicator } from "./components/RecordingIndicator";
@@ -50,16 +51,22 @@ export function ChatInput({
onDroppedFilesConsumed,
hasSession = false,
}: Props) {
const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
useCopilotUIStore();
const {
copilotChatMode,
setCopilotChatMode,
copilotLlmModel,
setCopilotLlmModel,
isDryRun,
setIsDryRun,
} = useCopilotUIStore();
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
const showDryRunToggle = showModeToggle;
const [files, setFiles] = useState<File[]>([]);
function handleToggleMode() {
const next =
copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
setCopilotMode(next);
copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking";
setCopilotChatMode(next);
toast({
title:
next === "fast"
@@ -72,6 +79,21 @@ export function ChatInput({
});
}
function handleToggleModel() {
const next = copilotLlmModel === "advanced" ? "standard" : "advanced";
setCopilotLlmModel(next);
toast({
title:
next === "advanced"
? "Switched to Advanced model"
: "Switched to Standard model",
description:
next === "advanced"
? "Using the highest-capability model."
: "Using the balanced standard model.",
});
}
function handleToggleDryRun() {
const next = !isDryRun;
setIsDryRun(next);
@@ -196,17 +218,28 @@ export function ChatInput({
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
{/* Mode and model are per-message settings sent with each stream request,
so they can be freely changed between turns in an existing session.
Hide only while actively streaming (too late to change for that turn). */}
{showModeToggle && !isStreaming && (
<ModeToggleButton
mode={copilotMode}
mode={copilotChatMode}
onToggle={handleToggleMode}
/>
)}
{showDryRunToggle && (!hasSession || isDryRun) && (
{showModeToggle && !isStreaming && (
<ModelToggleButton
model={copilotLlmModel}
onToggle={handleToggleModel}
/>
)}
{/* DryRun button only on new chats: once a session exists its
dry_run flag is locked and should be read from session metadata
(sessionDryRun in useCopilotPage), not toggled here. The banner
in CopilotPage.tsx reflects the actual session state. */}
{showDryRunToggle && !hasSession && (
<DryRunToggleButton
isDryRun={isDryRun}
isStreaming={isStreaming}
readOnly={hasSession}
onToggle={handleToggleDryRun}
/>
)}

View File

@@ -8,14 +8,23 @@ import { afterEach, describe, expect, it, vi } from "vitest";
import { ChatInput } from "../ChatInput";
let mockCopilotMode = "extended_thinking";
const mockSetCopilotMode = vi.fn((mode: string) => {
const mockSetCopilotChatMode = vi.fn((mode: string) => {
mockCopilotMode = mode;
});
let mockCopilotLlmModel = "standard";
const mockSetCopilotLlmModel = vi.fn((model: string) => {
mockCopilotLlmModel = model;
});
vi.mock("@/app/(platform)/copilot/store", () => ({
useCopilotUIStore: () => ({
copilotMode: mockCopilotMode,
setCopilotMode: mockSetCopilotMode,
copilotChatMode: mockCopilotMode,
setCopilotChatMode: mockSetCopilotChatMode,
copilotLlmModel: mockCopilotLlmModel,
setCopilotLlmModel: mockSetCopilotLlmModel,
isDryRun: false,
setIsDryRun: vi.fn(),
initialPrompt: null,
setInitialPrompt: vi.fn(),
}),
@@ -107,6 +116,7 @@ afterEach(() => {
cleanup();
vi.clearAllMocks();
mockCopilotMode = "extended_thinking";
mockCopilotLlmModel = "standard";
});
describe("ChatInput mode toggle", () => {
@@ -141,7 +151,7 @@ describe("ChatInput mode toggle", () => {
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast");
});
it("toggles from fast to extended_thinking on click", () => {
@@ -149,7 +159,7 @@ describe("ChatInput mode toggle", () => {
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
});
it("hides toggle button when streaming", () => {
@@ -158,6 +168,15 @@ describe("ChatInput mode toggle", () => {
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
});
it("shows mode toggle when hasSession is true and not streaming", () => {
// Mode is per-message — can be changed between turns even in an existing session.
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} hasSession />);
expect(
screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
).not.toBeNull();
});
it("exposes aria-pressed=true in extended_thinking mode", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
@@ -187,3 +206,93 @@ describe("ChatInput mode toggle", () => {
);
});
});
describe("ChatInput model toggle", () => {
it("renders model toggle button when flag is enabled", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined();
});
it("does not render model toggle when flag is disabled", () => {
mockFlagValue = false;
render(<ChatInput onSend={mockOnSend} />);
expect(
screen.queryByLabelText(/switch to (advanced|standard) model/i),
).toBeNull();
});
it("toggles from standard to advanced on click", () => {
mockFlagValue = true;
mockCopilotLlmModel = "standard";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
});
it("toggles from advanced to standard on click", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
});
it("hides model toggle when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
expect(
screen.queryByLabelText(/switch to (advanced|standard) model/i),
).toBeNull();
});
it("shows model toggle when hasSession is true and not streaming", () => {
// Model is per-message — can be changed between turns even in an existing session.
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} hasSession />);
expect(
screen.queryByLabelText(/switch to (advanced|standard) model/i),
).not.toBeNull();
});
it("hides dry-run toggle when hasSession is true", () => {
// DryRun button is only for new chats — once a session exists its dry_run
// flag is immutable and shown via the CopilotPage banner, not this button.
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} hasSession />);
expect(screen.queryByLabelText(/test mode/i)).toBeNull();
expect(screen.queryByLabelText(/enable test mode/i)).toBeNull();
});
it("shows dry-run toggle when no session", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByLabelText(/test mode|enable test mode/i)).toBeTruthy();
});
it("shows a toast when switching to advanced", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotLlmModel = "standard";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to advanced model/i),
}),
);
});
it("shows a toast when switching to standard", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to standard model/i),
}),
);
});
});

View File

@@ -3,42 +3,34 @@
import { cn } from "@/lib/utils";
import { Flask } from "@phosphor-icons/react";
// This button is only rendered on NEW chats (no active session).
// Once a session exists, it is hidden — the session's dry_run flag is
// immutable and reflected in the banner in CopilotPage.tsx instead.
// Do NOT add readOnly/hasSession handling here; hide it at the call site.
interface Props {
isDryRun: boolean;
isStreaming: boolean;
readOnly?: boolean;
onToggle: () => void;
}
export function DryRunToggleButton({
isDryRun,
isStreaming,
readOnly = false,
onToggle,
}: Props) {
const isDisabled = isStreaming || readOnly;
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
return (
<button
type="button"
aria-pressed={isDryRun}
disabled={isDisabled}
onClick={readOnly ? undefined : onToggle}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isDryRun
? "bg-amber-100 text-amber-900 hover:bg-amber-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
isDisabled && "cursor-default opacity-70",
)}
aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
aria-label={
isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
}
title={
readOnly
? "Test mode active for this session"
: isStreaming
? "Cannot change mode while streaming"
: isDryRun
? "Test mode ON — click to disable"
: "Enable Test mode — agents will run as dry-run"
isDryRun
? "Test mode ON — new chats run agents as simulation (click to disable)"
: "Enable Test mode — new chats will run agents as simulation"
}
>
<Flask size={14} />

View File

@@ -0,0 +1,38 @@
"use client";
import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import type { CopilotLlmModel } from "../../../store";
interface Props {
model: CopilotLlmModel;
onToggle: () => void;
}
export function ModelToggleButton({ model, onToggle }: Props) {
const isAdvanced = model === "advanced";
return (
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isAdvanced
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
}
title={
isAdvanced
? "Advanced model — highest capability (click to switch to Standard)"
: "Standard model — click to switch to Advanced"
}
>
<Cpu size={14} />
{isAdvanced && "Advanced"}
</button>
);
}

View File

@@ -0,0 +1,41 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { DryRunToggleButton } from "../DryRunToggleButton";
afterEach(cleanup);
// DryRunToggleButton only appears on new chats (no active session).
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
// the button entirely at the ChatInput level when hasSession is true.
describe("DryRunToggleButton", () => {
it("shows Test label when isDryRun is true", () => {
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
expect(screen.getByText("Test")).toBeTruthy();
});
it("shows no text label when isDryRun is false", () => {
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
expect(screen.queryByText("Test")).toBeNull();
});
it("calls onToggle when clicked", () => {
const onToggle = vi.fn();
render(<DryRunToggleButton isDryRun={false} onToggle={onToggle} />);
fireEvent.click(screen.getByRole("button"));
expect(onToggle).toHaveBeenCalledTimes(1);
});
it("sets aria-pressed=true when isDryRun is true", () => {
render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
expect(screen.getByRole("button").getAttribute("aria-pressed")).toBe(
"true",
);
});
it("sets aria-pressed=false when isDryRun is false", () => {
render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
expect(screen.getByRole("button").getAttribute("aria-pressed")).toBe(
"false",
);
});
});

View File

@@ -0,0 +1,37 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { ModelToggleButton } from "../ModelToggleButton";
afterEach(cleanup);
describe("ModelToggleButton", () => {
it("shows no text label when model is standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
expect(screen.queryByText("Standard")).toBeNull();
expect(screen.queryByText("Advanced")).toBeNull();
});
it("shows Advanced label when model is advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
expect(screen.getByText("Advanced")).toBeTruthy();
});
it("calls onToggle when clicked", () => {
const onToggle = vi.fn();
render(<ModelToggleButton model="standard" onToggle={onToggle} />);
fireEvent.click(screen.getByRole("button"));
expect(onToggle).toHaveBeenCalledTimes(1);
});
it("sets aria-pressed=false for standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Advanced model");
expect(btn.getAttribute("aria-pressed")).toBe("false");
});
it("sets aria-pressed=true for advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Standard model");
expect(btn.getAttribute("aria-pressed")).toBe("true");
});
});

View File

@@ -19,8 +19,16 @@ describe("formatResetTime", () => {
});
it("returns formatted date when over 24 hours away", () => {
const result = formatResetTime("2025-06-17T00:00:00Z", now);
expect(result).toMatch(/Tue/);
const resetsAt = "2025-06-17T00:00:00Z";
const result = formatResetTime(resetsAt, now);
const expected = new Date(resetsAt).toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
expect(result).toBe(expected);
});
it("accepts a Date object for resetsAt", () => {

View File

@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
import { getWebSocketToken } from "@/lib/supabase/actions";
import type { UIMessage } from "ai";
import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
export const ORIGINAL_TITLE = "AutoGPT";
/**
@@ -50,6 +52,24 @@ export function parseSessionIDs(raw: string | null | undefined): Set<string> {
}
}
/**
* Resolve the actual dry_run value for a session from the raw API response.
* Returns true only when the session response is a 200 with metadata.dry_run === true.
* Returns false for missing/non-200 responses so callers never show a stale
* preference value when the real session state is unknown.
*/
export function resolveSessionDryRun(queryData: unknown): boolean {
if (
queryData == null ||
typeof queryData !== "object" ||
!("status" in queryData) ||
(queryData as { status: unknown }).status !== 200
)
return false;
const d = queryData as { data?: { metadata?: { dry_run?: unknown } } };
return d.data?.metadata?.dry_run === true;
}
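Illustrative inputs and outputs for the guard above (shapes follow the raw query response it parses; values are hypothetical):

resolveSessionDryRun({ status: 200, data: { metadata: { dry_run: true } } }); // true
resolveSessionDryRun({ status: 200, data: { metadata: {} } });                // false
resolveSessionDryRun({ status: 404 });                                        // false (unknown is not dry-run)
resolveSessionDryRun(undefined);                                              // false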
/**
* Check whether a refetchSession result indicates the backend still has an
* active SSE stream for this session.
@@ -154,7 +174,18 @@ export function shouldSuppressDuplicateSend(
}
/**
* Deduplicate messages by ID and by content fingerprint.
* Fire-and-forget: tell the backend to release XREAD listeners for a session.
*
* Called on session switch so the backend doesn't wait for its 5-10 s timeout
* before cleaning up. Failures are silently ignored — the backend will
* eventually clean up on its own.
*/
export function disconnectSessionStream(sessionId: string): void {
deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}
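A hypothetical call-site sketch, assuming React's useEffect and a sessionId in scope: fire it from a cleanup effect on session switch and never await it (the helper already swallows failures):

useEffect(() => {
  return () => {
    // Release the backend XREAD listeners for the session we're leaving.
    if (sessionId) disconnectSessionStream(sessionId);
  };
}, [sessionId]);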
/**
* Deduplicate messages by ID and by consecutive content fingerprint.
*
* ID dedup catches exact duplicates within the same source.
* Content dedup uses a composite key of `role + preceding-user-message-id +

View File

@@ -99,6 +99,50 @@ describe("artifactPanel store actions", () => {
expect(s.history).toEqual([]);
});
it("openArtifact does not resurrect a previously closed artifact into history", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().closeArtifactPanel();
useCopilotUIStore.getState().openArtifact(b);
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe("b");
expect(s.history).toEqual([]);
});
it("openArtifact ignores non-previewable artifacts", () => {
const binary = {
...makeArtifact("bin", "artifact.bin"),
mimeType: "application/octet-stream",
};
useCopilotUIStore.getState().openArtifact(binary);
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("resetArtifactPanel clears active artifact and history", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().openArtifact(b);
useCopilotUIStore.getState().maximizeArtifactPanel();
useCopilotUIStore.getState().resetArtifactPanel();
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.isMinimized).toBe(false);
expect(s.isMaximized).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("minimize/restore toggles isMinimized without touching activeArtifact", () => {
const a = makeArtifact("a");
useCopilotUIStore.getState().openArtifact(a);
@@ -138,4 +182,35 @@ describe("artifactPanel store actions", () => {
expect(s.width).toBe(720);
expect(s.isMaximized).toBe(false);
});
it("history is capped at 25 entries (MAX_HISTORY)", () => {
// Open 27 artifacts sequentially (A0..A26). History should never exceed 25.
for (let i = 0; i < 27; i++) {
useCopilotUIStore.getState().openArtifact(makeArtifact(`a${i}`));
}
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.activeArtifact?.id).toBe("a26");
expect(s.history.length).toBe(25);
// The oldest entry (a0) should have been dropped; a1 is the earliest surviving.
expect(s.history[0].id).toBe("a1");
expect(s.history[24].id).toBe("a25");
});
it("clearCopilotLocalData resets artifact panel to default", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().openArtifact(b);
useCopilotUIStore.getState().maximizeArtifactPanel();
useCopilotUIStore.getState().clearCopilotLocalData();
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.isMinimized).toBe(false);
expect(s.isMaximized).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
expect(s.width).toBe(600); // DEFAULT_PANEL_WIDTH
});
});

View File

@@ -1,6 +1,7 @@
import { Key, storage } from "@/services/storage/local-storage";
import { create } from "zustand";
import { clearContentCache } from "./components/ArtifactPanel/components/useArtifactContent";
import { classifyArtifact } from "./components/ArtifactPanel/helpers";
import { ORIGINAL_TITLE, parseSessionIDs } from "./helpers";
export interface DeleteTarget {
@@ -52,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
/** Copilot response mode. */
export type CopilotMode = "extended_thinking" | "fast";
/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
export type CopilotLlmModel = "standard" | "advanced";
const isClient = typeof window !== "undefined";
function getPersistedWidth(): number {
@@ -92,6 +96,10 @@ function persistCompletedSessions(ids: Set<string>) {
}
}
function isPreviewableArtifact(ref: ArtifactRef): boolean {
return classifyArtifact(ref.mimeType, ref.title, ref.sizeBytes).openable;
}
interface CopilotUIState {
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
initialPrompt: string | null;
@@ -121,6 +129,7 @@ interface CopilotUIState {
artifactPanel: ArtifactPanelState;
openArtifact: (ref: ArtifactRef) => void;
closeArtifactPanel: () => void;
resetArtifactPanel: () => void;
minimizeArtifactPanel: () => void;
maximizeArtifactPanel: () => void;
restoreArtifactPanel: () => void;
@@ -128,8 +137,12 @@ interface CopilotUIState {
goBackArtifact: () => void;
/** Copilot chat mode: 'extended_thinking' (default) or 'fast'. */
copilotMode: CopilotMode;
setCopilotMode: (mode: CopilotMode) => void;
copilotChatMode: CopilotMode;
setCopilotChatMode: (mode: CopilotMode) => void;
/** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
copilotLlmModel: CopilotLlmModel;
setCopilotLlmModel: (model: CopilotLlmModel) => void;
/** Developer dry-run mode: sessions created with dry_run=true. */
isDryRun: boolean;
@@ -203,14 +216,20 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
},
openArtifact: (ref) =>
set((state) => {
if (!isPreviewableArtifact(ref)) return state;
const { activeArtifact, history: prevHistory } = state.artifactPanel;
const topOfHistory = prevHistory[prevHistory.length - 1];
const isReturningToTop = topOfHistory?.id === ref.id;
const shouldPushHistory =
state.artifactPanel.isOpen &&
activeArtifact != null &&
activeArtifact.id !== ref.id;
const MAX_HISTORY = 25;
const history = isReturningToTop
? prevHistory.slice(0, -1)
: activeArtifact && activeArtifact.id !== ref.id
? [...prevHistory, activeArtifact].slice(-MAX_HISTORY)
: shouldPushHistory
? [...prevHistory, activeArtifact!].slice(-MAX_HISTORY)
: prevHistory;
return {
artifactPanel: {
@@ -231,6 +250,17 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
history: [],
},
})),
resetArtifactPanel: () =>
set((state) => ({
artifactPanel: {
...state.artifactPanel,
isOpen: false,
isMinimized: false,
isMaximized: false,
activeArtifact: null,
history: [],
},
})),
minimizeArtifactPanel: () =>
set((state) => ({
artifactPanel: { ...state.artifactPanel, isMinimized: true },
@@ -275,9 +305,22 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
};
}),
copilotMode: "extended_thinking",
setCopilotMode: (mode) => {
set({ copilotMode: mode });
copilotChatMode: (() => {
const saved = isClient ? storage.get(Key.COPILOT_MODE) : null;
return saved === "fast" ? "fast" : "extended_thinking";
})(),
setCopilotChatMode: (mode) => {
storage.set(Key.COPILOT_MODE, mode);
set({ copilotChatMode: mode });
},
copilotLlmModel: (() => {
const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null;
return saved === "advanced" ? "advanced" : "standard";
})(),
setCopilotLlmModel: (model) => {
storage.set(Key.COPILOT_MODEL, model);
set({ copilotLlmModel: model });
},
isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
@@ -299,6 +342,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
storage.clean(Key.COPILOT_DRY_RUN);
storage.clean(Key.COPILOT_MODE);
storage.clean(Key.COPILOT_MODEL);
set({
completedSessionIDs: new Set<string>(),
isNotificationsEnabled: false,
@@ -311,7 +356,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
activeArtifact: null,
history: [],
},
copilotMode: "extended_thinking",
copilotChatMode: "extended_thinking",
copilotLlmModel: "standard",
isDryRun: false,
});
if (isClient) {
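With these changes the two toggles round-trip through local storage and are wiped by clearCopilotLocalData; a minimal sketch of the expected behavior, assuming the storage/Key facade imported above:

useCopilotUIStore.getState().setCopilotLlmModel("advanced");
storage.get(Key.COPILOT_MODEL);               // "advanced" (persisted)
useCopilotUIStore.getState().clearCopilotLocalData();
useCopilotUIStore.getState().copilotLlmModel; // "standard" (reset to default)
storage.get(Key.COPILOT_MODEL);               // cleaned (empty per the facade)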

View File

@@ -1,15 +1,13 @@
"use client";
import React, { useState } from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { Button } from "@/components/atoms/Button/Button";
import type { BlockOutputResponse } from "@/app/api/__generated__/models/blockOutputResponse";
import {
globalRegistry,
OutputItem,
} from "@/components/contextual/OutputRenderers";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
import { resolveForRenderer } from "@/app/(platform)/copilot/tools/ViewAgentOutput/ViewAgentOutput";
import {
ContentBadge,
ContentCard,
@@ -24,28 +22,6 @@ interface Props {
const COLLAPSED_LIMIT = 3;
function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
if (!isWorkspaceURI(value)) return { value };
const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;
const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}
return { value: url, metadata };
}
function RenderOutputValue({ value }: { value: unknown }) {
const resolved = resolveForRenderer(value);
const renderer = globalRegistry.getRenderer(
@@ -63,16 +39,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}
// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}
return null;
}

View File

@@ -2,7 +2,6 @@
import type { ToolUIPart } from "ai";
import React from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
globalRegistry,
OutputItem,
@@ -47,7 +46,7 @@ interface Props {
part: ViewAgentOutputToolPart;
}
function resolveForRenderer(value: unknown): {
export function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
@@ -56,17 +55,17 @@ function resolveForRenderer(value: unknown): {
const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;
// Pass workspace URIs through to the registry unchanged.
// WorkspaceFileRenderer (priority 50) matches workspace:// URIs and
// handles URL building, loading skeletons, and error states internally.
// Previously this function converted the URI to a proxy URL, which
// bypassed WorkspaceFileRenderer and caused ImageRenderer (a bare
// <img>) to match instead.
const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}
return { value: url, metadata };
return { value, metadata };
}
function RenderOutputValue({ value }: { value: unknown }) {
@@ -86,16 +85,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}
// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}
return null;
}
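
For context on the comment above: the lookup behaves like a priority-ordered scan over registered renderers, which is why leaving the raw workspace:// URI intact lets the priority-50 WorkspaceFileRenderer claim the value before ImageRenderer can match. A simplified sketch of that mechanism, not the actual globalRegistry implementation (the interface shape here is assumed):

// Simplified sketch of a priority-based renderer registry lookup.
// The real globalRegistry API may differ; this only illustrates why
// keeping the workspace:// URI unchanged lets the higher-priority
// renderer win.
interface OutputRendererSketch {
  name: string;
  priority: number; // higher wins
  canRender(value: unknown, metadata?: { mimeType?: string }): boolean;
}

function getRendererSketch(
  renderers: OutputRendererSketch[],
  value: unknown,
  metadata?: { mimeType?: string },
): OutputRendererSketch | undefined {
  return [...renderers]
    .sort((a, b) => b.priority - a.priority)
    .find((r) => r.canRender(value, metadata));
}
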


@@ -0,0 +1,52 @@
import { describe, expect, it } from "vitest";
import { resolveForRenderer } from "../ViewAgentOutput";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
describe("resolveForRenderer", () => {
it("preserves workspace image URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://abc123#image/png");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("image/png");
});
it("preserves workspace video URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://vid456#video/mp4");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("video/mp4");
});
it("passes non-workspace values through unchanged", () => {
const result = resolveForRenderer("just a string");
expect(result.value).toBe("just a string");
expect(result.metadata).toBeUndefined();
});
it("passes non-string values through unchanged", () => {
const obj = { foo: "bar" };
const result = resolveForRenderer(obj);
expect(result.value).toBe(obj);
expect(result.metadata).toBeUndefined();
});
it("workspace image URIs match WorkspaceFileRenderer with loading/error states", () => {
// WorkspaceFileRenderer (priority 50) should handle workspace:// URIs
// since resolveForRenderer no longer pre-converts them to proxy URLs.
const resolved = resolveForRenderer("workspace://abc123#image/png");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});
it("workspace video URIs match WorkspaceFileRenderer", () => {
const resolved = resolveForRenderer("workspace://vid456#video/mp4");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});
});
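
The fixtures encode the file ID before the fragment and the MIME type after it, i.e. workspace://<fileID>#<mimeType>. A sketch of the parsing these tests imply; the real parseWorkspaceURI in @/lib/workspace-uri may validate more strictly:

// Sketch of the URI shape the fixtures imply. Not the real helper.
function parseWorkspaceURISketch(
  value: unknown,
): { fileID: string; mimeType?: string } | null {
  if (typeof value !== "string" || !value.startsWith("workspace://")) {
    return null;
  }
  const rest = value.slice("workspace://".length);
  const [fileID, mimeType] = rest.split("#", 2);
  if (!fileID) return null;
  return mimeType ? { fileID, mimeType } : { fileID };
}

// parseWorkspaceURISketch("workspace://abc123#image/png")
//   → { fileID: "abc123", mimeType: "image/png" }
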


@@ -10,6 +10,7 @@ import { useQueryClient } from "@tanstack/react-query";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useMemo, useRef } from "react";
import { convertChatSessionMessagesToUiMessages } from "./helpers/convertChatSessionToUiMessages";
import { resolveSessionDryRun } from "./helpers";
interface UseChatSessionOptions {
dryRun?: boolean;
@@ -163,6 +164,18 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
? ((sessionQuery.data.data.messages ?? []) as unknown[])
: [];
// The actual dry_run value stored in the session's metadata, read directly
// from the API response. This reflects what the session was ACTUALLY created
// with — not the user's current UI preference (isDryRun store).
//
// Design intent: the global isDryRun store is only used when creating NEW
// sessions. Once a session exists, its dry_run flag is immutable and should
// be read from here rather than from the store, which may have changed.
const sessionDryRun = useMemo(
() => resolveSessionDryRun(sessionQuery.data),
[sessionQuery.data],
);
return {
sessionId,
setSessionId,
@@ -177,5 +190,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
createSession,
isCreatingSession,
refetchSession: sessionQuery.refetch,
sessionDryRun,
};
}
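
resolveSessionDryRun itself is not shown in this diff. A plausible shape, assuming the session payload nests the flag under data.metadata.dry_run (the response type and field path are guesses, not confirmed by the diff):

// Plausible shape of the helper imported from "./helpers" above.
// The exact response shape and field path are assumptions.
type SessionResponseSketch =
  | { status: number; data?: { metadata?: { dry_run?: boolean } } }
  | undefined;

export function resolveSessionDryRunSketch(
  response: SessionResponseSketch,
): boolean {
  // Only trust the flag the session was actually created with;
  // default to false while the session is still loading.
  return response?.data?.metadata?.dry_run === true;
}
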


@@ -42,7 +42,8 @@ export function useCopilotPage() {
setSessionToDelete,
isDrawerOpen,
setDrawerOpen,
copilotMode,
copilotChatMode,
copilotLlmModel,
isDryRun,
} = useCopilotUIStore();
@@ -60,6 +61,7 @@ export function useCopilotPage() {
createSession,
isCreatingSession,
refetchSession,
sessionDryRun,
} = useChatSession({ dryRun: isDryRun });
const {
@@ -78,7 +80,8 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
copilotMode: isModeToggleEnabled ? copilotChatMode : undefined,
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { olderMessages, hasMore, isLoadingMore, loadMore } =
@@ -416,6 +419,11 @@ export function useCopilotPage() {
rateLimitMessage,
dismissRateLimit,
// Dry run dev toggle
// isDryRun = global preference for NEW sessions (from localStorage).
// sessionDryRun = actual dry_run value of the CURRENT session (from API).
// Use isDryRun to configure future sessions; use sessionDryRun to display
// the current session's simulation state (banner, indicators).
isDryRun,
sessionDryRun,
};
}
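
One way a page component might act on the distinction the comment above draws, gating the banner on the session's real flag rather than the stored preference; the component and banner names are illustrative, not part of this diff:

// Illustrative banner gating: render the test-mode banner only from
// the current session's actual dry_run state, never from the isDryRun
// preference, so stale preferences cannot produce a banner.
function CopilotDryRunBannerSketch({
  sessionId,
  sessionDryRun,
}: {
  sessionId: string | null;
  sessionDryRun: boolean;
}) {
  if (!sessionId || !sessionDryRun) return null;
  return <div role="status">Test mode: this session simulates tool calls.</div>;
}
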

Some files were not shown because too many files have changed in this diff.