mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00

Compare commits: 38 commits (feat/incre... ... fix/copilo...)
| Author | SHA1 | Date |
|---|---|---|
| | 42ccef316a | |
| | 26b5f9958b | |
| | a281e38620 | |
| | 66cda847c7 | |
| | f72bddb51e | |
| | a96517ccb5 | |
| | 9493c55108 | |
| | d23ca824ad | |
| | 4ab5c64c09 | |
| | 227c60abd3 | |
| | d0592d63a6 | |
| | fa064aa4f1 | |
| | 6bdfadf903 | |
| | 737aa20f80 | |
| | 5f79164c53 | |
| | ff65f58ba9 | |
| | cc89f245ce | |
| | 8202f48e46 | |
| | 43e8159822 | |
| | 25e34829bc | |
| | 6a091a17d2 | |
| | 0284614df0 | |
| | 5cfb6ffdaa | |
| | f49a9f728c | |
| | 53925d2e2b | |
| | aa2d2d7371 | |
| | 661fffe133 | |
| | f835674498 | |
| | da18f372f7 | |
| | d82ecac363 | |
| | 8a2e2365f7 | |
| | 55869d3c75 | |
| | 142c5dbe99 | |
| | b06648de8c | |
| | 7240dd4fb1 | |
| | b4cd00bea9 | |
| | e17914d393 | |
| | b3a58389e5 | |
@@ -48,14 +48,15 @@ git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
 For each changed file, determine:
 
 1. **Is it a page?** (`page.tsx`) — these are the primary test targets
-2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
+2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
 3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
 4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
 
 **Priority order:**
 
 1. Pages with new/changed data fetching or user interactions
 2. Components with complex internal logic (modals, forms, wizards)
-3. Hooks with non-trivial business logic
+3. Shared hooks with standalone business logic when UI-level coverage is impractical
 4. Pure helper functions
 
 Skip: styling-only changes, type-only changes, config changes.
@@ -163,6 +164,7 @@ describe("LibraryPage", () => {
 - Use `waitFor` when asserting side effects or state changes after interactions
 - Import `fireEvent` or `userEvent` from the test-utils for interactions
 - Do NOT mock internal hooks or functions — mock at the API boundary via MSW
+- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
 - Do NOT use `act()` manually — `render` and `fireEvent` handle it
 - Keep tests focused: one behavior per test
 - Use descriptive test names that read like sentences
@@ -190,9 +192,7 @@ import { http, HttpResponse } from "msw";
 server.use(
   http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
     return HttpResponse.json({
-      agents: [
-        { id: "1", name: "Test Agent", description: "A test agent" },
-      ],
+      agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
      pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
    });
  }),
@@ -211,6 +211,7 @@ pnpm test:unit --reporter=verbose
 ```
 
 If tests fail:
 
 1. Read the error output carefully
 2. Fix the test (not the source code, unless there is a genuine bug)
 3. Re-run until all pass
13  .github/workflows/platform-fullstack-ci.yml (vendored)

@@ -160,6 +160,7 @@ jobs:
       run: |
         cp ../backend/.env.default ../backend/.env
         echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
+        echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
       env:
         # Used by E2E test data script to generate embeddings for approved store agents
         OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -288,6 +289,14 @@ jobs:
           cache: "pnpm"
           cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

+      - name: Set up tests - Cache Playwright browsers
+        uses: actions/cache@v5
+        with:
+          path: ~/.cache/ms-playwright
+          key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
+          restore-keys: |
+            playwright-${{ runner.os }}-
+
       - name: Copy source maps from Docker for E2E coverage
         run: |
           FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
@@ -299,8 +308,8 @@ jobs:
       - name: Set up tests - Install browser 'chromium'
         run: pnpm playwright install --with-deps chromium

-      - name: Run Playwright tests
-        run: pnpm test:no-build
+      - name: Run Playwright E2E suite
+        run: pnpm test:e2e:no-build
+        continue-on-error: false

       - name: Upload E2E coverage to Codecov
1  .gitignore (vendored)

@@ -194,3 +194,4 @@ test.db
 .next
 # Implementation plans (generated by AI agents)
 plans/
+.claude/worktrees/
166  autogpt_platform/backend/agents/calculator-agent.json (new file)

@@ -0,0 +1,166 @@
{
  "id": "858e2226-e047-4d19-a832-3be4a134d155",
  "version": 2,
  "is_active": true,
  "name": "Calculator agent",
  "description": "",
  "instructions": null,
  "recommended_schedule_cron": null,
  "forked_from_id": null,
  "forked_from_version": null,
  "user_id": "",
  "created_at": "2026-04-13T03:45:11.241Z",
  "nodes": [
    {
      "id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
      "block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
      "input_default": {
        "name": "Input",
        "secret": false,
        "advanced": false
      },
      "metadata": {
        "position": {
          "x": -188.2244873046875,
          "y": 95
        }
      },
      "input_links": [],
      "output_links": [
        {
          "id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
          "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
          "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
          "source_name": "result",
          "sink_name": "a",
          "is_static": true
        }
      ],
      "graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
      "graph_version": 2,
      "webhook_id": null
    },
    {
      "id": "65429c9e-a0c6-4032-a421-6899c394fa74",
      "block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
      "input_default": {
        "name": "Output",
        "secret": false,
        "advanced": false,
        "escape_html": false
      },
      "metadata": {
        "position": {
          "x": 825.198974609375,
          "y": 123.75
        }
      },
      "input_links": [
        {
          "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
          "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
          "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
          "source_name": "result",
          "sink_name": "value",
          "is_static": false
        }
      ],
      "output_links": [],
      "graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
      "graph_version": 2,
      "webhook_id": null
    },
    {
      "id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
      "block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
      "input_default": {
        "b": 34,
        "operation": "Add",
        "round_result": false
      },
      "metadata": {
        "position": {
          "x": 323.0255126953125,
          "y": 121.25
        }
      },
      "input_links": [
        {
          "id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
          "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
          "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
          "source_name": "result",
          "sink_name": "a",
          "is_static": true
        }
      ],
      "output_links": [
        {
          "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
          "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
          "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
          "source_name": "result",
          "sink_name": "value",
          "is_static": false
        }
      ],
      "graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
      "graph_version": 2,
      "webhook_id": null
    }
  ],
  "links": [
    {
      "id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
      "source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
      "sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
      "source_name": "result",
      "sink_name": "value",
      "is_static": false
    },
    {
      "id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
      "source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
      "sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
      "source_name": "result",
      "sink_name": "a",
      "is_static": true
    }
  ],
  "sub_graphs": [],
  "input_schema": {
    "type": "object",
    "properties": {
      "Input": {
        "advanced": false,
        "secret": false,
        "title": "Input"
      }
    },
    "required": [
      "Input"
    ]
  },
  "output_schema": {
    "type": "object",
    "properties": {
      "Output": {
        "advanced": false,
        "secret": false,
        "title": "Output"
      }
    },
    "required": [
      "Output"
    ]
  },
  "has_external_trigger": false,
  "has_human_in_the_loop": false,
  "has_sensitive_action": false,
  "trigger_setup_info": null,
  "credentials_input_schema": {
    "type": "object",
    "properties": {},
    "required": []
  }
}
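The agent file above encodes a three-node graph: an Input block feeds a Calculator block (operation "Add", b = 34), whose result feeds an Output block, and the top-level `links` mirror each node's `input_links`/`output_links`. A minimal sketch of a consistency check over that structure (the helper below is illustrative only, not part of this changeset):

```python
import json

def check_graph_links(path: str) -> None:
    """Verify that every top-level link is mirrored on both endpoint nodes,
    as in calculator-agent.json above."""
    with open(path) as f:
        graph = json.load(f)

    nodes = {node["id"]: node for node in graph["nodes"]}
    for link in graph["links"]:
        source = nodes[link["source_id"]]
        sink = nodes[link["sink_id"]]
        # The same link (by id) must appear in the source's output_links
        # and the sink's input_links.
        assert any(l["id"] == link["id"] for l in source["output_links"])
        assert any(l["id"] == link["id"] for l in sink["input_links"])

check_graph_links("autogpt_platform/backend/agents/calculator-agent.json")
```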
@@ -43,6 +43,7 @@ async def get_cost_dashboard(
     model: str | None = Query(None),
     block_name: str | None = Query(None),
     tracking_type: str | None = Query(None),
+    graph_exec_id: str | None = Query(None),
 ):
     logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
     return await get_platform_cost_dashboard(
@@ -53,6 +54,7 @@ async def get_cost_dashboard(
         model=model,
         block_name=block_name,
         tracking_type=tracking_type,
+        graph_exec_id=graph_exec_id,
     )


@@ -72,6 +74,7 @@ async def get_cost_logs(
     model: str | None = Query(None),
     block_name: str | None = Query(None),
     tracking_type: str | None = Query(None),
+    graph_exec_id: str | None = Query(None),
 ):
     logger.info("Admin %s fetching platform cost logs", admin_user_id)
     logs, total = await get_platform_cost_logs(
@@ -84,6 +87,7 @@ async def get_cost_logs(
         model=model,
         block_name=block_name,
         tracking_type=tracking_type,
+        graph_exec_id=graph_exec_id,
     )
     total_pages = (total + page_size - 1) // page_size
     return PlatformCostLogsResponse(
@@ -117,6 +121,7 @@ async def export_cost_logs(
     model: str | None = Query(None),
     block_name: str | None = Query(None),
     tracking_type: str | None = Query(None),
+    graph_exec_id: str | None = Query(None),
 ):
     logger.info("Admin %s exporting platform cost logs", admin_user_id)
     logs, truncated = await get_platform_cost_logs_for_export(
@@ -127,6 +132,7 @@ async def export_cost_logs(
         model=model,
         block_name=block_name,
         tracking_type=tracking_type,
+        graph_exec_id=graph_exec_id,
     )
     return PlatformCostExportResponse(
         logs=logs,
@@ -15,9 +15,10 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator

 from backend.copilot import service as chat_service
 from backend.copilot import stream_registry
-from backend.copilot.config import ChatConfig, CopilotMode
+from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
 from backend.copilot.db import get_chat_messages_paginated
 from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
+from backend.copilot.message_dedup import acquire_dedup_lock
 from backend.copilot.model import (
     ChatMessage,
     ChatSession,
@@ -139,6 +140,11 @@ class StreamChatRequest(BaseModel):
         description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
         "If None, uses the server default (extended_thinking).",
     )
+    model: CopilotLlmModel | None = Field(
+        default=None,
+        description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
+        "If None, the server applies per-user LD targeting then falls back to config.",
+    )


 class CreateSessionRequest(BaseModel):
@@ -376,6 +382,31 @@ async def delete_session(
     return Response(status_code=204)


+@router.delete(
+    "/sessions/{session_id}/stream",
+    dependencies=[Security(auth.requires_user)],
+    status_code=204,
+)
+async def disconnect_session_stream(
+    session_id: str,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> Response:
+    """Disconnect all active SSE listeners for a session.
+
+    Called by the frontend when the user switches away from a chat so the
+    backend releases XREAD listeners immediately rather than waiting for
+    the 5-10 s timeout.
+    """
+    session = await get_chat_session(session_id, user_id)
+    if not session:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Session {session_id} not found or access denied",
+        )
+    await stream_registry.disconnect_all_listeners(session_id)
+    return Response(status_code=204)
+
+
 @router.patch(
     "/sessions/{session_id}/title",
     summary="Update session title",
@@ -810,6 +841,9 @@ async def stream_chat_post(
     # Also sanitise file_ids so only validated, workspace-scoped IDs are
     # forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
     sanitized_file_ids: list[str] | None = None
+    # Capture the original message text BEFORE any mutation (attachment enrichment)
+    # so the idempotency hash is stable across retries.
+    original_message = request.message
     if request.file_ids and user_id:
         # Filter to valid UUIDs only to prevent DB abuse
         valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -838,60 +872,91 @@ async def stream_chat_post(
         )
         request.message += files_block

+    # ── Idempotency guard ────────────────────────────────────────────────────
+    # Blocks duplicate executor tasks from concurrent/retried POSTs.
+    # See backend/copilot/message_dedup.py for the full lifecycle description.
+    dedup_lock = None
+    if request.is_user_message:
+        dedup_lock = await acquire_dedup_lock(
+            session_id, original_message, sanitized_file_ids
+        )
+        if dedup_lock is None and (original_message or sanitized_file_ids):
+
+            async def _empty_sse() -> AsyncGenerator[str, None]:
+                yield StreamFinish().to_sse()
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(
+                _empty_sse(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "X-Accel-Buffering": "no",
+                    "Connection": "keep-alive",
+                    "x-vercel-ai-ui-message-stream": "v1",
+                },
+            )
+
     # Atomically append user message to session BEFORE creating task to avoid
     # race condition where GET_SESSION sees task as "running" but message isn't
     # saved yet. append_and_save_message re-fetches inside a lock to prevent
     # message loss from concurrent requests.
-    if request.message:
-        message = ChatMessage(
-            role="user" if request.is_user_message else "assistant",
-            content=request.message,
-        )
-        if request.is_user_message:
-            track_user_message(
-                user_id=user_id,
-                session_id=session_id,
-                message_length=len(request.message),
-            )
-        logger.info(f"[STREAM] Saving user message to session {session_id}")
-        await append_and_save_message(session_id, message)
-        logger.info(f"[STREAM] User message saved for session {session_id}")
+    #
+    # If any of these operations raises, release the dedup lock before propagating
+    # so subsequent retries are not blocked for 30 s.
+    try:
+        if request.message:
+            message = ChatMessage(
+                role="user" if request.is_user_message else "assistant",
+                content=request.message,
+            )
+            logger.info(f"[STREAM] Saving user message to session {session_id}")
+            await append_and_save_message(session_id, message)
+            logger.info(f"[STREAM] User message saved for session {session_id}")
+            if request.is_user_message:
+                track_user_message(
+                    user_id=user_id,
+                    session_id=session_id,
+                    message_length=len(request.message),
+                )

-    # Create a task in the stream registry for reconnection support
-    turn_id = str(uuid4())
-    log_meta["turn_id"] = turn_id
+        # Create a task in the stream registry for reconnection support
+        turn_id = str(uuid4())
+        log_meta["turn_id"] = turn_id

-    session_create_start = time.perf_counter()
-    await stream_registry.create_session(
-        session_id=session_id,
-        user_id=user_id,
-        tool_call_id="chat_stream",
-        tool_name="chat",
-        turn_id=turn_id,
-    )
-    logger.info(
-        f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
-        extra={
-            "json_fields": {
-                **log_meta,
-                "duration_ms": (time.perf_counter() - session_create_start) * 1000,
-            }
-        },
-    )
+        session_create_start = time.perf_counter()
+        await stream_registry.create_session(
+            session_id=session_id,
+            user_id=user_id,
+            tool_call_id="chat_stream",
+            tool_name="chat",
+            turn_id=turn_id,
+        )
+        logger.info(
+            f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
+            extra={
+                "json_fields": {
+                    **log_meta,
+                    "duration_ms": (time.perf_counter() - session_create_start) * 1000,
+                }
+            },
+        )

-    # Per-turn stream is always fresh (unique turn_id), subscribe from beginning
-    subscribe_from_id = "0-0"
-
-    await enqueue_copilot_turn(
-        session_id=session_id,
-        user_id=user_id,
-        message=request.message,
-        turn_id=turn_id,
-        is_user_message=request.is_user_message,
-        context=request.context,
-        file_ids=sanitized_file_ids,
-        mode=request.mode,
-    )
+        await enqueue_copilot_turn(
+            session_id=session_id,
+            user_id=user_id,
+            message=request.message,
+            turn_id=turn_id,
+            is_user_message=request.is_user_message,
+            context=request.context,
+            file_ids=sanitized_file_ids,
+            mode=request.mode,
+            model=request.model,
+        )
+    except Exception:
+        if dedup_lock:
+            await dedup_lock.release()
+        raise

     setup_time = (time.perf_counter() - stream_start_time) * 1000
     logger.info(
@@ -899,6 +964,9 @@ async def stream_chat_post(
         extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
     )

+    # Per-turn stream is always fresh (unique turn_id), subscribe from beginning
+    subscribe_from_id = "0-0"
+
     # SSE endpoint that subscribes to the task's stream
     async def event_generator() -> AsyncGenerator[str, None]:
         import time as time_module
@@ -912,6 +980,12 @@ async def stream_chat_post(
         subscriber_queue = None
         first_chunk_yielded = False
         chunks_yielded = 0
+        # True for every exit path except GeneratorExit (client disconnect).
+        # On disconnect the backend turn is still running — releasing the lock
+        # there would reopen the infra-retry duplicate window. The 30 s TTL
+        # is the fallback. All other exits (normal finish, early return, error)
+        # should release so the user can re-send the same message.
+        release_dedup_lock_on_exit = True
         try:
             # Subscribe from the position we captured before enqueuing
             # This avoids replaying old messages while catching all new ones
@@ -923,8 +997,7 @@ async def stream_chat_post(

             if subscriber_queue is None:
                 yield StreamFinish().to_sse()
                 yield "data: [DONE]\n\n"
-                return
+                return  # finally releases dedup_lock

             # Read from the subscriber queue and yield to SSE
             logger.info(
@@ -953,7 +1026,6 @@ async def stream_chat_post(

                     yield chunk.to_sse()

                     # Check for finish signal
                     if isinstance(chunk, StreamFinish):
                         total_time = time_module.perf_counter() - event_gen_start
                         logger.info(
@@ -967,7 +1039,8 @@ async def stream_chat_post(
                                 }
                             },
                         )
-                        break
+                        break  # finally releases dedup_lock

                 except asyncio.TimeoutError:
                     yield StreamHeartbeat().to_sse()
@@ -982,7 +1055,7 @@ async def stream_chat_post(
                     }
                 },
             )
-            pass  # Client disconnected - background task continues
+            release_dedup_lock_on_exit = False
         except Exception as e:
             elapsed = (time_module.perf_counter() - event_gen_start) * 1000
             logger.error(
@@ -997,7 +1070,10 @@ async def stream_chat_post(
                 code="stream_error",
             ).to_sse()
             yield StreamFinish().to_sse()
+            # finally releases dedup_lock
         finally:
+            if dedup_lock and release_dedup_lock_on_exit:
+                await dedup_lock.release()
             # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
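The routes above lean on backend/copilot/message_dedup.py, which this compare does not show. A rough sketch of the lifecycle implied by the route comments and the tests below, assuming a Redis SET NX key with a 30 s TTL and the chat:msg_dedup key format those tests assert (the separator used when joining multiple file IDs is a guess; the Redis accessor's home module is not shown, so its import is elided):

```python
import hashlib

DEDUP_TTL_SECONDS = 30  # duplicate window, per the route comments above

class DedupLock:
    def __init__(self, redis, key: str):
        self._redis = redis
        self.key = key

    async def release(self) -> None:
        try:
            await self._redis.delete(self.key)
        except Exception:
            pass  # releasing is best-effort; the TTL is the backstop

async def acquire_dedup_lock(session_id, message, file_ids):
    """Return a lock for a fresh message, or None when the same POST
    already acquired the key within the TTL window."""
    redis = await get_redis_async()  # project Redis helper, assumed in scope
    digest = hashlib.sha256(
        f"{session_id}:{message}:{','.join(sorted(file_ids or []))}".encode()
    ).hexdigest()[:16]
    key = f"chat:msg_dedup:{session_id}:{digest}"
    # SET NX succeeds only for the first writer within the TTL window.
    if await redis.set(key, "1", nx=True, ex=DEDUP_TTL_SECONDS):
        return DedupLock(redis, key)
    return None
```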
@@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids():
     assert response.status_code == 422


-def _mock_stream_internals(mocker: pytest_mock.MockFixture):
-    """Mock the async internals of stream_chat_post so tests can exercise
-    validation and enrichment logic without needing Redis/RabbitMQ."""
+def _mock_stream_internals(
+    mocker: pytest_mock.MockerFixture,
+    *,
+    redis_set_returns: object = True,
+):
+    """Mock the async internals of stream_chat_post so tests can exercise
+    validation and enrichment logic without needing Redis/RabbitMQ.
+
+    Args:
+        redis_set_returns: Value returned by the mocked Redis ``set`` call.
+            ``True`` (default) simulates a fresh key (new message);
+            ``None`` simulates a collision (duplicate blocked).
+
+    Returns:
+        A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
+        callers can make additional assertions about side-effects.
+    """
+    import types
+
     mocker.patch(
         "backend.api.features.chat.routes._validate_and_get_session",
         return_value=None,
     )
-    mocker.patch(
+    mock_save = mocker.patch(
         "backend.api.features.chat.routes.append_and_save_message",
         return_value=None,
     )
@@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
         "backend.api.features.chat.routes.stream_registry",
         mock_registry,
     )
-    mocker.patch(
+    mock_enqueue = mocker.patch(
         "backend.api.features.chat.routes.enqueue_copilot_turn",
         return_value=None,
     )
@@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
         "backend.api.features.chat.routes.track_user_message",
         return_value=None,
     )
+    mock_redis = AsyncMock()
+    mock_redis.set = AsyncMock(return_value=redis_set_returns)
+    mocker.patch(
+        "backend.copilot.message_dedup.get_redis_async",
+        new_callable=AsyncMock,
+        return_value=mock_redis,
+    )
+    ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
+    return ns


-def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
+def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
     """Exactly 20 file_ids should be accepted (not rejected by validation)."""
     _mock_stream_internals(mocker)
     # Patch workspace lookup as imported by the routes module
@@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
 # ─── UUID format filtering ─────────────────────────────────────────────


-def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
+def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
     """Non-UUID strings in file_ids should be silently filtered out
     and NOT passed to the database query."""
     _mock_stream_internals(mocker)
@@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
 # ─── Cross-workspace file_ids ─────────────────────────────────────────


-def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
+def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
     """The batch query should scope to the user's workspace."""
     _mock_stream_internals(mocker)
     mocker.patch(
@@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
 # ─── Rate limit → 429 ─────────────────────────────────────────────────


-def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
+def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture):
     """When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
     from backend.copilot.rate_limit import RateLimitExceeded

@@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix
     assert "daily" in response.json()["detail"].lower()


-def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
+def test_stream_chat_returns_429_on_weekly_rate_limit(
+    mocker: pytest_mock.MockerFixture,
+):
     """When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
     from backend.copilot.rate_limit import RateLimitExceeded

@@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi
     assert "resets in" in detail


-def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
+def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
     """The 429 response detail should include the human-readable reset time."""
     from backend.copilot.rate_limit import RateLimitExceeded

@@ -677,3 +704,279 @@ class TestStripInjectedContext:
         result = _strip_injected_context(msg)
         # Without a role, the helper short-circuits without touching content.
         assert result["content"] == "hello"
+
+
+# ─── Idempotency / duplicate-POST guard ──────────────────────────────
+
+
+def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """A second POST with the same message within the 30-s window must return
+    an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
+    turn complete without creating a ghost response."""
+    # redis_set_returns=None simulates a collision: the NX key already exists.
+    ns = _mock_stream_internals(mocker, redis_set_returns=None)
+
+    response = client.post(
+        "/sessions/sess-dup/stream",
+        json={"message": "duplicate message", "is_user_message": True},
+    )
+
+    assert response.status_code == 200
+    body = response.text
+    # The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
+    assert '"finish"' in body
+    assert "[DONE]" in body
+    # The empty SSE response must include the AI SDK protocol header so the
+    # frontend treats it as a valid stream and marks the turn complete.
+    assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
+    # The duplicate guard must prevent save/enqueue side effects.
+    ns.save.assert_not_called()
+    ns.enqueue.assert_not_called()
+
+
+def test_stream_chat_first_post_proceeds_normally(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """The first POST (Redis NX key set successfully) must proceed through the
+    normal streaming path — no early return."""
+    ns = _mock_stream_internals(mocker, redis_set_returns=True)
+
+    response = client.post(
+        "/sessions/sess-new/stream",
+        json={"message": "first message", "is_user_message": True},
+    )
+
+    assert response.status_code == 200
+    # Redis set must have been called once with the NX flag.
+    ns.redis.set.assert_called_once()
+    call_kwargs = ns.redis.set.call_args
+    assert call_kwargs.kwargs.get("nx") is True
+
+
+def test_stream_chat_dedup_skipped_for_non_user_messages(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """System/assistant messages (is_user_message=False) bypass the dedup
+    guard — they are injected programmatically and must always be processed."""
+    ns = _mock_stream_internals(mocker, redis_set_returns=None)
+
+    response = client.post(
+        "/sessions/sess-sys/stream",
+        json={"message": "system context", "is_user_message": False},
+    )
+
+    # Even though redis_set_returns=None (would block a user message),
+    # the endpoint must proceed because is_user_message=False.
+    assert response.status_code == 200
+    ns.redis.set.assert_not_called()
+
+
+def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """The dedup hash must be computed from the original request message,
+    not the mutated version that has the [Attached files] block appended.
+    A file_id is sent so the route actually appends the [Attached files] block,
+    exercising the mutation path — the hash must still match the original text."""
+    import hashlib
+
+    ns = _mock_stream_internals(mocker, redis_set_returns=True)
+
+    file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
+    # Mock workspace + prisma so the attachment block is actually appended.
+    mocker.patch(
+        "backend.api.features.chat.routes.get_or_create_workspace",
+        return_value=type("W", (), {"id": "ws-1"})(),
+    )
+    fake_file = type(
+        "F",
+        (),
+        {
+            "id": file_id,
+            "name": "doc.pdf",
+            "mimeType": "application/pdf",
+            "sizeBytes": 1024,
+        },
+    )()
+    mock_prisma = mocker.MagicMock()
+    mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
+    mocker.patch(
+        "prisma.models.UserWorkspaceFile.prisma",
+        return_value=mock_prisma,
+    )
+
+    response = client.post(
+        "/sessions/sess-hash/stream",
+        json={
+            "message": "plain message",
+            "is_user_message": True,
+            "file_ids": [file_id],
+        },
+    )
+
+    assert response.status_code == 200
+    ns.redis.set.assert_called_once()
+    call_args = ns.redis.set.call_args
+    dedup_key = call_args.args[0]
+
+    # Hash must use the original message + sorted file IDs, not the mutated text.
+    expected_hash = hashlib.sha256(
+        f"sess-hash:plain message:{file_id}".encode()
+    ).hexdigest()[:16]
+    expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
+    assert dedup_key == expected_key, (
+        f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
+        "hash may be using mutated message or wrong inputs"
+    )
+
+
+def test_stream_chat_dedup_key_released_after_stream_finish(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """The dedup Redis key must be deleted after the turn completes (when
+    subscriber_queue is None the route yields StreamFinish immediately and
+    should release the key so the user can re-send the same message)."""
+    from unittest.mock import AsyncMock as _AsyncMock
+
+    # Set up all internals manually so we can control subscribe_to_session.
+    mocker.patch(
+        "backend.api.features.chat.routes._validate_and_get_session",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.append_and_save_message",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.enqueue_copilot_turn",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.track_user_message",
+        return_value=None,
+    )
+    mock_registry = mocker.MagicMock()
+    mock_registry.create_session = _AsyncMock(return_value=None)
+    # None → early-finish path: StreamFinish yielded immediately, dedup key released.
+    mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
+    mocker.patch(
+        "backend.api.features.chat.routes.stream_registry",
+        mock_registry,
+    )
+    mock_redis = mocker.AsyncMock()
+    mock_redis.set = _AsyncMock(return_value=True)
+    mocker.patch(
+        "backend.copilot.message_dedup.get_redis_async",
+        new_callable=_AsyncMock,
+        return_value=mock_redis,
+    )
+
+    response = client.post(
+        "/sessions/sess-finish/stream",
+        json={"message": "hello", "is_user_message": True},
+    )
+
+    assert response.status_code == 200
+    body = response.text
+    assert '"finish"' in body
+    # The dedup key must be released so intentional re-sends are allowed.
+    mock_redis.delete.assert_called_once()
+
+
+def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    """The route must not crash when the dedup Redis delete fails on the
+    subscriber_queue-is-None early-finish path (except Exception: pass)."""
+    from unittest.mock import AsyncMock as _AsyncMock
+
+    mocker.patch(
+        "backend.api.features.chat.routes._validate_and_get_session",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.append_and_save_message",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.enqueue_copilot_turn",
+        return_value=None,
+    )
+    mocker.patch(
+        "backend.api.features.chat.routes.track_user_message",
+        return_value=None,
+    )
+    mock_registry = mocker.MagicMock()
+    mock_registry.create_session = _AsyncMock(return_value=None)
+    mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
+    mocker.patch(
+        "backend.api.features.chat.routes.stream_registry",
+        mock_registry,
+    )
+    mock_redis = mocker.AsyncMock()
+    mock_redis.set = _AsyncMock(return_value=True)
+    # Make the delete raise so the except-pass branch is exercised.
+    mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
+    mocker.patch(
+        "backend.copilot.message_dedup.get_redis_async",
+        new_callable=_AsyncMock,
+        return_value=mock_redis,
+    )
+
+    # Should not raise even though delete fails.
+    response = client.post(
+        "/sessions/sess-finish-err/stream",
+        json={"message": "hello", "is_user_message": True},
+    )
+
+    assert response.status_code == 200
+    assert '"finish"' in response.text
+    # delete must have been attempted — the except-pass branch silenced the error.
+    mock_redis.delete.assert_called_once()
+
+
+# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
+
+
+def test_disconnect_stream_returns_204_and_awaits_registry(
+    mocker: pytest_mock.MockerFixture,
+    test_user_id: str,
+) -> None:
+    mock_session = MagicMock()
+    mocker.patch(
+        "backend.api.features.chat.routes.get_chat_session",
+        new_callable=AsyncMock,
+        return_value=mock_session,
+    )
+    mock_disconnect = mocker.patch(
+        "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
+        new_callable=AsyncMock,
+        return_value=2,
+    )
+
+    response = client.delete("/sessions/sess-1/stream")
+
+    assert response.status_code == 204
+    mock_disconnect.assert_awaited_once_with("sess-1")
+
+
+def test_disconnect_stream_returns_404_when_session_missing(
+    mocker: pytest_mock.MockerFixture,
+    test_user_id: str,
+) -> None:
+    mocker.patch(
+        "backend.api.features.chat.routes.get_chat_session",
+        new_callable=AsyncMock,
+        return_value=None,
+    )
+    mock_disconnect = mocker.patch(
+        "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
+        new_callable=AsyncMock,
+    )
+
+    response = client.delete("/sessions/unknown-session/stream")
+
+    assert response.status_code == 404
+    mock_disconnect.assert_not_awaited()
@@ -25,6 +25,7 @@ from backend.data.model import (
     Credentials,
     CredentialsFieldInfo,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
     is_credentials_field_name,
 )
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from backend.data.execution import ExecutionContext
-    from backend.data.model import ContributorDetails, NodeExecutionStats
+    from backend.data.model import ContributorDetails

     from ..data.graph import Link

@@ -420,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
 class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
     _optimized_description: ClassVar[str | None] = None

+    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
+        """Return extra runtime cost to charge after this block run completes.
+
+        Called by the executor after a block finishes with COMPLETED status.
+        The return value is the number of additional base-cost credits to
+        charge beyond the single credit already collected by charge_usage
+        at the start of execution. Defaults to 0 (no extra charges).
+
+        Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
+        calls within one run and should be billed per call.
+        """
+        return 0
+
     def __init__(
         self,
         id: str = "",
@@ -455,8 +469,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         disabled: If the block is disabled, it will not be available for execution.
         static_output: Whether the output links of the block are static by default.
         """
-        from backend.data.model import NodeExecutionStats
-
         self.id = id
         self.input_schema = input_schema
         self.output_schema = output_schema
@@ -474,7 +486,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         self.is_sensitive_action = is_sensitive_action
         # Read from ClassVar set by initialize_blocks()
         self.optimized_description: str | None = type(self)._optimized_description
-        self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
+        self.execution_stats: NodeExecutionStats = NodeExecutionStats()

         if self.webhook_config:
             if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -554,7 +566,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
             return data
         raise ValueError(f"{self.name} did not produce any output for {output}")

-    def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
+    def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
         self.execution_stats += stats
         return self.execution_stats
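The new hook is the whole post-completion billing surface: return how many extra base-cost credits to charge once the run finishes. A minimal sketch of a hypothetical multi-call block overriding it (the OrchestratorBlock change further down does exactly this for LLM calls):

```python
from backend.data.model import NodeExecutionStats

class BatchSummariseBlock(Block):  # hypothetical block, for illustration only
    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
        # charge_usage already collected one credit when the run started, so
        # only the LLM calls beyond the first are billed after completion.
        return max(0, execution_stats.llm_call_count - 1)
```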
@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util import json
 from backend.util.clients import get_database_manager_async_client
+from backend.util.exceptions import InsufficientBalanceError
 from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
 from backend.util.security import SENSITIVE_FIELD_NAMES
 from backend.util.tool_call_loop import (
@@ -364,10 +365,31 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:


 class OrchestratorBlock(Block):
-    """
-    A block that uses a language model to orchestrate tool calls, supporting both
-    single-shot and iterative agent mode execution.
-    """
+    """A block that uses a language model to orchestrate tool calls.
+
+    Supports both single-shot and iterative agent mode execution.
+
+    **InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
+    (IBE) must always re-raise through every ``except`` block in this class.
+    Swallowing IBE would let the agent loop continue with unpaid work. Every
+    exception handler that catches ``Exception`` includes an explicit IBE
+    re-raise carve-out for this reason.
+    """
+
+    def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
+        """Charge one extra runtime cost per LLM call beyond the first.
+
+        In agent mode each iteration makes one LLM call. The first is already
+        covered by charge_usage(); this returns the number of additional
+        credits so the executor can bill the remaining calls post-completion.
+
+        SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
+        the SDK manages its own conversation loop and only exposes aggregate
+        usage. We hardcode llm_call_count=1 there (the SDK does not report a
+        per-turn call count), so this method always returns 0 for SDK-mode
+        executions. Per-iteration billing does not apply to SDK mode.
+        """
+        return max(0, execution_stats.llm_call_count - 1)

     # MCP server name used by the Claude Code SDK execution mode. Keep in sync
     # with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1077,7 +1099,10 @@ class OrchestratorBlock(Block):
                 input_data=input_value,
             )

-            assert node_exec_result is not None, "node_exec_result should not be None"
+            if node_exec_result is None:
+                raise RuntimeError(
+                    f"upsert_execution_input returned None for node {sink_node_id}"
+                )

             # Create NodeExecutionEntry for execution manager
             node_exec_entry = NodeExecutionEntry(
@@ -1112,15 +1137,86 @@ class OrchestratorBlock(Block):
                 task=node_exec_future,
             )

-            # Execute the node directly since we're in the Orchestrator context
-            node_exec_future.set_result(
-                await execution_processor.on_node_execution(
-                    node_exec=node_exec_entry,
-                    node_exec_progress=node_exec_progress,
-                    nodes_input_masks=None,
-                    graph_stats_pair=graph_stats_pair,
-                )
-            )
+            # Execute the node directly since we're in the Orchestrator context.
+            # Wrap in try/except so the future is always resolved, even on
+            # error — an unresolved Future would block anything awaiting it.
+            #
+            # on_node_execution is decorated with @async_error_logged(swallow=True),
+            # which catches BaseException and returns None rather than raising.
+            # Treat a None return as a failure: set_exception so the future
+            # carries an error state rather than a None result, and return an
+            # error response so the LLM knows the tool failed.
+            try:
+                tool_node_stats = await execution_processor.on_node_execution(
+                    node_exec=node_exec_entry,
+                    node_exec_progress=node_exec_progress,
+                    nodes_input_masks=None,
+                    graph_stats_pair=graph_stats_pair,
+                )
+                if tool_node_stats is None:
+                    nil_err = RuntimeError(
+                        f"on_node_execution returned None for node {sink_node_id} "
+                        "(error was swallowed by @async_error_logged)"
+                    )
+                    node_exec_future.set_exception(nil_err)
+                    resp = _create_tool_response(
+                        tool_call.id,
+                        "Tool execution returned no result",
+                        responses_api=responses_api,
+                    )
+                    resp["_is_error"] = True
+                    return resp
+                node_exec_future.set_result(tool_node_stats)
+            except Exception as exec_err:
+                node_exec_future.set_exception(exec_err)
+                raise
+
+            # Charge user credits AFTER successful tool execution. Tools
+            # spawned by the orchestrator bypass the main execution queue
+            # (where _charge_usage is called), so we must charge here to
+            # avoid free tool execution. Charging post-completion (vs.
+            # pre-execution) avoids billing users for failed tool calls.
+            # Skipped for dry runs.
+            #
+            # `error is None` intentionally excludes both Exception and
+            # BaseException subclasses (e.g. CancelledError) so cancelled
+            # or terminated tool runs are not billed.
+            #
+            # Billing errors (including non-balance exceptions) are kept
+            # in a separate try/except so they are never silently swallowed
+            # by the generic tool-error handler below.
+            if (
+                not execution_params.execution_context.dry_run
+                and tool_node_stats.error is None
+            ):
+                try:
+                    tool_cost, _ = await execution_processor.charge_node_usage(
+                        node_exec_entry,
+                    )
+                except InsufficientBalanceError:
+                    # IBE must propagate — see OrchestratorBlock class docstring.
+                    # Log the billing failure here so the discarded tool result
+                    # is traceable before the loop aborts.
+                    logger.warning(
+                        "Insufficient balance charging for tool node %s after "
+                        "successful execution; agent loop will be aborted",
+                        sink_node_id,
+                    )
+                    raise
+                except Exception:
+                    # Non-billing charge failures (DB outage, network, etc.)
+                    # must NOT propagate to the outer except handler because
+                    # the tool itself succeeded. Re-raising would mark the
+                    # tool as failed (_is_error=True), causing the LLM to
+                    # retry side-effectful operations. Log and continue.
+                    logger.exception(
+                        "Unexpected error charging for tool node %s; "
+                        "tool execution was successful",
+                        sink_node_id,
+                    )
+                    tool_cost = 0
+                if tool_cost > 0:
+                    self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))

             # Get outputs from database after execution completes using database manager client
             node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
@@ -1133,18 +1229,26 @@ class OrchestratorBlock(Block):
                 if node_outputs
                 else "Tool executed successfully"
             )
-            return _create_tool_response(
+            resp = _create_tool_response(
                 tool_call.id, tool_response_content, responses_api=responses_api
             )
+            resp["_is_error"] = False
+            return resp

+        except InsufficientBalanceError:
+            # IBE must propagate — see class docstring.
+            raise
         except Exception as e:
-            logger.warning("Tool execution with manager failed: %s", e)
-            # Return error response
-            return _create_tool_response(
+            logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
+            # Return a generic error to the LLM — internal exception messages
+            # may contain server paths, DB details, or infrastructure info.
+            resp = _create_tool_response(
                 tool_call.id,
-                f"Tool execution failed: {e}",
+                "Tool execution failed due to an internal error",
                 responses_api=responses_api,
             )
+            resp["_is_error"] = True
+            return resp

     async def _agent_mode_llm_caller(
         self,
@@ -1244,13 +1348,16 @@ class OrchestratorBlock(Block):
                 content = str(raw_content)
             else:
                 content = "Tool executed successfully"
-            tool_failed = content.startswith("Tool execution failed:")
+            tool_failed = result.get("_is_error", True)
             return ToolCallResult(
                 tool_call_id=tool_call.id,
                 tool_name=tool_call.name,
                 content=content,
                 is_error=tool_failed,
             )
+        except InsufficientBalanceError:
+            # IBE must propagate — see class docstring.
+            raise
         except Exception as e:
             logger.error("Tool execution failed: %s", e)
             return ToolCallResult(
@@ -1370,9 +1477,13 @@ class OrchestratorBlock(Block):
                         "arguments": tc.arguments,
                     },
                 )
+        except InsufficientBalanceError:
+            # IBE must propagate — see class docstring.
+            raise
         except Exception as e:
-            # Catch all errors (validation, network, API) so that the block
-            # surfaces them as user-visible output instead of crashing.
+            # Catch all OTHER errors (validation, network, API) so that
+            # the block surfaces them as user-visible output instead of
+            # crashing.
             yield "error", str(e)
             return

@@ -1450,11 +1561,14 @@ class OrchestratorBlock(Block):
                 text = content
             else:
                 text = json.dumps(content)
-            tool_failed = text.startswith("Tool execution failed:")
+            tool_failed = result.get("_is_error", True)
             return {
                 "content": [{"type": "text", "text": text}],
                 "isError": tool_failed,
             }
+        except InsufficientBalanceError:
+            # IBE must propagate — see class docstring.
+            raise
         except Exception as e:
             logger.error("SDK tool execution failed: %s", e)
             return {
@@ -1733,11 +1847,15 @@ class OrchestratorBlock(Block):
                 await pending_task
             except (asyncio.CancelledError, StopAsyncIteration):
                 pass
+            except InsufficientBalanceError:
+                # IBE must propagate — see class docstring. The `finally`
+                # block below still runs and records partial token usage.
+                raise
             except Exception as e:
-                # Surface SDK errors as user-visible output instead of crashing,
-                # consistent with _execute_tools_agent_mode error handling.
-                # Don't return yet — fall through to merge_stats below so
-                # partial token usage is always recorded.
+                # Surface OTHER SDK errors as user-visible output instead
+                # of crashing, consistent with _execute_tools_agent_mode
+                # error handling. Don't return yet — fall through to
+                # merge_stats below so partial token usage is always recorded.
                 sdk_error = e
             finally:
                 # Always record usage stats, even on error. The SDK may have
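The propagation contract above reduces every handler in the class to one recurring shape. A distilled sketch of that carve-out pattern (illustrative helper, not a line from this diff):

```python
from backend.util.exceptions import InsufficientBalanceError

async def run_tool_safely(tool):
    try:
        return await tool()
    except InsufficientBalanceError:
        raise  # billing failures must abort the agent loop, never be swallowed
    except Exception as exc:
        # Every other failure is surfaced to the LLM as a tool error, without
        # leaking internal exception detail.
        return {"is_error": True, "content": f"Tool failed: {type(exc).__name__}"}
```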
@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
     mock_execution_processor.on_node_execution = AsyncMock(
         return_value=mock_node_stats
     )
+    # Mock charge_node_usage (called after successful tool execution).
+    # Returns (cost, remaining_balance). Must be AsyncMock because it is
+    # an async method and is directly awaited in _execute_single_tool_with_manager.
+    # Use a non-zero cost so the merge_stats branch is exercised.
+    mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))

     # Mock the get_execution_outputs_by_node_exec_id method
     mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
     # Verify tool was executed via execution processor
     assert mock_execution_processor.on_node_execution.call_count == 1

+    # Verify charge_node_usage was actually called for the successful
+    # tool execution — this guards against regressions where the
+    # post-execution tool charging is accidentally removed.
+    assert mock_execution_processor.charge_node_usage.call_count == 1
+

 @pytest.mark.asyncio
 async def test_orchestrator_traditional_mode_default():
@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
         mock_execution_processor.on_node_execution.return_value = (
             mock_node_stats
         )
+        # Mock charge_node_usage (called after successful tool execution).
+        # Must be AsyncMock because it is async and is awaited in
+        # _execute_single_tool_with_manager — a plain MagicMock would
+        # return a non-awaitable tuple and TypeError out, then be
+        # silently swallowed by the orchestrator's catch-all.
+        mock_execution_processor.charge_node_usage = AsyncMock(
+            return_value=(0, 0)
+        )

         async for output_name, output_value in block.run(
             input_data,
File diff suppressed because it is too large
@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
     ep.execution_stats_lock = threading.Lock()
     ns = MagicMock(error=None)
     ep.on_node_execution = AsyncMock(return_value=ns)
+    # Mock charge_node_usage (called after successful tool execution).
+    # Must be AsyncMock because it is async and is awaited in
+    # _execute_single_tool_with_manager — a plain MagicMock would return a
+    # non-awaitable tuple and TypeError out, then be silently swallowed by
+    # the orchestrator's catch-all.
+    ep.charge_node_usage = AsyncMock(return_value=(0, 0))

     with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
         block, "_create_tool_node_signatures", return_value=tool_sigs
@@ -103,6 +103,7 @@ _TRANSCRIPT_UPLOAD_TIMEOUT_S = 5
|
||||
# MIME types that can be embedded as vision content blocks (OpenAI format).
|
||||
_VISION_MIME_TYPES = frozenset({"image/png", "image/jpeg", "image/gif", "image/webp"})
|
||||
|
||||
|
||||
# Max size for embedding images directly in the user message (20 MiB raw).
|
||||
_MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024
|
||||
|
||||
@@ -247,6 +248,8 @@ class _BaselineStreamState:
|
||||
text_started: bool = False
|
||||
turn_prompt_tokens: int = 0
|
||||
turn_completion_tokens: int = 0
|
||||
turn_cache_read_tokens: int = 0
|
||||
turn_cache_creation_tokens: int = 0
|
||||
cost_usd: float | None = None
|
||||
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
|
||||
session_messages: list[ChatMessage] = field(default_factory=list)
|
||||
@@ -294,6 +297,18 @@ async def _baseline_llm_caller(
|
||||
if chunk.usage:
|
||||
state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0
|
||||
state.turn_completion_tokens += chunk.usage.completion_tokens or 0
|
||||
# Extract cache token details when available (OpenAI /
|
||||
# OpenRouter include these in prompt_tokens_details).
|
||||
ptd = getattr(chunk.usage, "prompt_tokens_details", None)
|
||||
if ptd:
|
||||
state.turn_cache_read_tokens += (
|
||||
getattr(ptd, "cached_tokens", 0) or 0
|
||||
)
|
||||
# cache_creation_input_tokens is reported by some providers
|
||||
# (e.g. Anthropic native) but not standard OpenAI streaming.
|
||||
state.turn_cache_creation_tokens += (
|
||||
getattr(ptd, "cache_creation_input_tokens", 0) or 0
|
||||
)
|
||||
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if not delta:
|
||||
@@ -1190,16 +1205,22 @@ async def stream_chat_completion_baseline(
            state.turn_prompt_tokens,
            state.turn_completion_tokens,
        )

        # Persist token usage to session and record for rate limiting.
        # NOTE: OpenRouter folds cached tokens into prompt_tokens, so we
        # cannot break out cache_read/cache_creation weights. Users on the
        # baseline path may be slightly over-counted vs the SDK path.
        # When prompt_tokens_details.cached_tokens is reported, subtract
        # them from prompt_tokens to get the uncached count so the cost
        # breakdown stays accurate.
        uncached_prompt = state.turn_prompt_tokens
        if state.turn_cache_read_tokens > 0:
            uncached_prompt = max(
                0, state.turn_prompt_tokens - state.turn_cache_read_tokens
            )
        await persist_and_record_usage(
            session=session,
            user_id=user_id,
-           prompt_tokens=state.turn_prompt_tokens,
+           prompt_tokens=uncached_prompt,
            completion_tokens=state.turn_completion_tokens,
            cache_read_tokens=state.turn_cache_read_tokens,
            cache_creation_tokens=state.turn_cache_creation_tokens,
            log_prefix="[Baseline]",
            cost_usd=state.cost_usd,
            model=active_model,
@@ -1269,10 +1290,13 @@ async def stream_chat_completion_baseline(
    # On GeneratorExit the client is already gone, so unreachable yields
    # are harmless; on normal completion they reach the SSE stream.
    if state.turn_prompt_tokens > 0 or state.turn_completion_tokens > 0:
        # Report uncached prompt tokens to match what was billed — cached tokens
        # are excluded so the frontend display is consistent with cost_usd.
        billed_prompt = max(0, state.turn_prompt_tokens - state.turn_cache_read_tokens)
        yield StreamUsage(
-           prompt_tokens=state.turn_prompt_tokens,
+           prompt_tokens=billed_prompt,
            completion_tokens=state.turn_completion_tokens,
-           total_tokens=state.turn_prompt_tokens + state.turn_completion_tokens,
+           total_tokens=billed_prompt + state.turn_completion_tokens,
        )

    yield StreamFinish()
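A quick arithmetic check of the billed/uncached split, using the same numbers as the tests below (1000 prompt tokens of which 800 were served from cache). This is a sketch only, not the service code:

```python
turn_prompt_tokens = 1000        # everything the provider counted as input
turn_cache_read_tokens = 800     # portion served from the prompt cache
turn_completion_tokens = 200

# Same max(0, ...) guard as the service: a provider that reports more
# cached tokens than prompt tokens must not produce a negative count.
billed_prompt = max(0, turn_prompt_tokens - turn_cache_read_tokens)

assert billed_prompt == 200
assert billed_prompt + turn_completion_tokens == 400  # StreamUsage.total_tokens
```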
@@ -769,3 +769,244 @@ class TestBaselineCostExtraction:

        # response was never assigned so cost extraction must not raise
        assert state.cost_usd is None

    @pytest.mark.asyncio
    async def test_no_cost_when_header_missing(self):
        """cost_usd remains None when x-total-cost is absent."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="anthropic/claude-sonnet-4")

        mock_raw = MagicMock()
        mock_raw.headers = {}  # no x-total-cost
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        mock_chunk = MagicMock()
        mock_chunk.usage = MagicMock()
        mock_chunk.usage.prompt_tokens = 1000
        mock_chunk.usage.completion_tokens = 500
        mock_chunk.usage.prompt_tokens_details = None
        mock_chunk.choices = []

        async def chunk_aiter():
            yield mock_chunk

        mock_stream.__aiter__ = lambda self: chunk_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd is None

    @pytest.mark.asyncio
    async def test_cache_tokens_extracted_from_usage_details(self):
        """cache tokens are extracted from prompt_tokens_details.cached_tokens."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="openai/gpt-4o")

        mock_raw = MagicMock()
        mock_raw.headers = {"x-total-cost": "0.01"}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        # Create a chunk with prompt_tokens_details
        mock_ptd = MagicMock()
        mock_ptd.cached_tokens = 800

        mock_chunk = MagicMock()
        mock_chunk.usage = MagicMock()
        mock_chunk.usage.prompt_tokens = 1000
        mock_chunk.usage.completion_tokens = 200
        mock_chunk.usage.prompt_tokens_details = mock_ptd
        mock_chunk.choices = []

        async def chunk_aiter():
            yield mock_chunk

        mock_stream.__aiter__ = lambda self: chunk_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.turn_cache_read_tokens == 800
        assert state.turn_prompt_tokens == 1000

    @pytest.mark.asyncio
    async def test_cache_creation_tokens_extracted_from_usage_details(self):
        """cache_creation_tokens are extracted from prompt_tokens_details."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="openai/gpt-4o")

        mock_raw = MagicMock()
        mock_raw.headers = {"x-total-cost": "0.01"}
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        mock_ptd = MagicMock()
        mock_ptd.cached_tokens = 0
        mock_ptd.cache_creation_input_tokens = 500

        mock_chunk = MagicMock()
        mock_chunk.usage = MagicMock()
        mock_chunk.usage.prompt_tokens = 1000
        mock_chunk.usage.completion_tokens = 200
        mock_chunk.usage.prompt_tokens_details = mock_ptd
        mock_chunk.choices = []

        async def chunk_aiter():
            yield mock_chunk

        mock_stream.__aiter__ = lambda self: chunk_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.turn_cache_creation_tokens == 500

    @pytest.mark.asyncio
    async def test_token_accumulators_track_across_multiple_calls(self):
        """Token accumulators grow correctly across multiple _baseline_llm_caller calls."""
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="anthropic/claude-sonnet-4")

        def make_stream(prompt_tokens: int, completion_tokens: int):
            mock_raw = MagicMock()
            mock_raw.headers = {}  # no x-total-cost
            mock_stream = MagicMock()
            mock_stream._response = mock_raw

            mock_chunk = MagicMock()
            mock_chunk.usage = MagicMock()
            mock_chunk.usage.prompt_tokens = prompt_tokens
            mock_chunk.usage.completion_tokens = completion_tokens
            mock_chunk.usage.prompt_tokens_details = None
            mock_chunk.choices = []

            async def chunk_aiter():
                yield mock_chunk

            mock_stream.__aiter__ = lambda self: chunk_aiter()
            return mock_stream

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=[
                make_stream(1000, 200),
                make_stream(1100, 300),
            ]
        )

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "follow up"}],
                tools=[],
                state=state,
            )

        # No x-total-cost header and empty pricing table -- cost_usd remains None
        assert state.cost_usd is None
        # Accumulators hold all tokens across both turns
        assert state.turn_prompt_tokens == 2100
        assert state.turn_completion_tokens == 500

    @pytest.mark.asyncio
    async def test_cost_usd_remains_none_when_header_missing(self):
        """cost_usd stays None when x-total-cost header is absent.

        Token counts are still tracked; persist_and_record_usage handles
        the None cost by falling back to tracking_type='tokens'.
        """
        from backend.copilot.baseline.service import (
            _baseline_llm_caller,
            _BaselineStreamState,
        )

        state = _BaselineStreamState(model="anthropic/claude-sonnet-4")

        mock_raw = MagicMock()
        mock_raw.headers = {}  # no x-total-cost
        mock_stream = MagicMock()
        mock_stream._response = mock_raw

        mock_chunk = MagicMock()
        mock_chunk.usage = MagicMock()
        mock_chunk.usage.prompt_tokens = 1000
        mock_chunk.usage.completion_tokens = 500
        mock_chunk.usage.prompt_tokens_details = None
        mock_chunk.choices = []

        async def chunk_aiter():
            yield mock_chunk

        mock_stream.__aiter__ = lambda self: chunk_aiter()

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)

        with patch(
            "backend.copilot.baseline.service._get_openai_client",
            return_value=mock_client,
        ):
            await _baseline_llm_caller(
                messages=[{"role": "user", "content": "hi"}],
                tools=[],
                state=state,
            )

        assert state.cost_usd is None
        assert state.turn_prompt_tokens == 1000
        assert state.turn_completion_tokens == 500
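The `mock_stream.__aiter__ = lambda self: chunk_aiter()` trick used throughout these tests deserves a note: assigning a function that takes `self` to a supported magic method makes a MagicMock async-iterable without building a custom class. A standalone demonstration (standard library only):

```python
import asyncio
from unittest.mock import MagicMock


async def main() -> None:
    mock_stream = MagicMock()

    async def chunk_aiter():
        yield "chunk-1"
        yield "chunk-2"

    # MagicMock installs explicitly assigned magic methods on the mock's
    # type, so `async for` calls the lambda and gets a fresh async generator.
    mock_stream.__aiter__ = lambda self: chunk_aiter()

    chunks = [c async for c in mock_stream]
    assert chunks == ["chunk-1", "chunk-2"]


asyncio.run(main())
```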
@@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]

# Per-request model tier set by the frontend model toggle.
# 'standard' uses the global config default (currently Sonnet).
# 'advanced' forces the highest-capability model (currently Opus).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]


class ChatConfig(BaseSettings):
    """Configuration for the chat system."""
@@ -163,12 +170,12 @@ class ChatConfig(BaseSettings):
        "CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
    )
    claude_agent_max_budget_usd: float = Field(
-       default=15.0,
+       default=10.0,
        ge=0.01,
        le=1000.0,
        description="Maximum spend in USD per SDK query. The CLI attempts "
        "to wrap up gracefully when this budget is reached. "
-       "Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
+       "Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
        "Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
    )
    claude_agent_max_thinking_tokens: int = Field(
@@ -197,6 +204,15 @@ class ChatConfig(BaseSettings):
        description="Maximum number of retries for transient API errors "
        "(429, 5xx, ECONNRESET) before surfacing the error to the user.",
    )
    claude_agent_cross_user_prompt_cache: bool = Field(
        default=True,
        description="Enable cross-user prompt caching via SystemPromptPreset. "
        "The Claude Code default prompt becomes a cacheable prefix shared "
        "across all users, and our custom prompt is appended after it. "
        "Dynamic sections (working dir, git status, auto-memory) are excluded "
        "from the prefix. Set to False to fall back to passing the system "
        "prompt as a raw string.",
    )
    claude_agent_cli_path: str | None = Field(
        default=None,
        description="Optional explicit path to a Claude Code CLI binary. "
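Since ChatConfig is a pydantic BaseSettings class, the new default can be checked and overridden per environment. A minimal sketch, assuming the `CHAT_` env prefix implied by the field descriptions above and that ChatConfig lives in `backend.copilot.config` as the surrounding hunks suggest:

```python
import os

from backend.copilot.config import ChatConfig

# Defaults come from the Field(...) declarations.
assert ChatConfig().claude_agent_max_budget_usd == 10.0

# BaseSettings reads the environment on instantiation, so the env var
# named in the description overrides the default without code changes.
os.environ["CHAT_CLAUDE_AGENT_MAX_BUDGET_USD"] = "5.0"
assert ChatConfig().claude_agent_max_budget_usd == 5.0
```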
@@ -351,6 +351,7 @@ class CoPilotProcessor:
                context=entry.context,
                file_ids=entry.file_ids,
                mode=effective_mode,
                model=entry.model,
            )
            async for chunk in stream_registry.stream_and_publish(
                session_id=entry.session_id,
@@ -9,7 +9,7 @@ import logging

from pydantic import BaseModel

-from backend.copilot.config import CopilotMode
+from backend.copilot.config import CopilotLlmModel, CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled

@@ -160,6 +163,9 @@ class CoPilotExecutionEntry(BaseModel):
    mode: CopilotMode | None = None
    """Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""

    model: CopilotLlmModel | None = None
    """Per-request model tier: 'standard' or 'advanced'. None = server default."""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""
@@ -180,6 +183,7 @@ async def enqueue_copilot_turn(
    context: dict[str, str] | None = None,
    file_ids: list[str] | None = None,
    mode: CopilotMode | None = None,
    model: CopilotLlmModel | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

@@ -192,6 +196,7 @@ async def enqueue_copilot_turn(
        context: Optional context for the message (e.g., {url: str, content: str})
        file_ids: Optional workspace file IDs attached to the user's message
        mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
        model: Per-request model tier ('standard' or 'advanced'). None = server default.
    """
    from backend.util.clients import get_async_copilot_queue

@@ -204,6 +209,7 @@ async def enqueue_copilot_turn(
        context=context,
        file_ids=file_ids,
        mode=mode,
        model=model,
    )

    queue_client = await get_async_copilot_queue()
71  autogpt_platform/backend/backend/copilot/message_dedup.py  (new file)
@@ -0,0 +1,71 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
|
||||
session_id: str,
|
||||
message: str | None,
|
||||
file_ids: list[str] | None,
|
||||
) -> _DedupLock | None:
|
||||
"""Acquire the idempotency lock for this (session, message, files) tuple.
|
||||
|
||||
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
|
||||
Returns ``None`` when a duplicate is detected (lock already held).
|
||||
Returns ``None`` when there is nothing to deduplicate (no message, no files).
|
||||
"""
|
||||
if not message and not file_ids:
|
||||
return None
|
||||
|
||||
sorted_ids = ":".join(sorted(file_ids or []))
|
||||
content_hash = hashlib.sha256(
|
||||
f"{session_id}:{message or ''}:{sorted_ids}".encode()
|
||||
).hexdigest()[:16]
|
||||
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
|
||||
|
||||
redis = await get_redis_async()
|
||||
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"[STREAM] Duplicate user message blocked for session {session_id}, "
|
||||
f"hash={content_hash} — returning empty SSE",
|
||||
)
|
||||
return None
|
||||
|
||||
return _DedupLock(key, redis)
|
||||
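How a caller is expected to hold the lock across a turn, per the lifecycle notes above. This is a sketch of the intended usage, not the actual /stream handler; the handler name and control flow are illustrative:

```python
from backend.copilot.message_dedup import acquire_dedup_lock


async def handle_stream_request(session_id: str, message: str):
    # message is non-empty here, so None can only mean "duplicate in flight".
    lock = await acquire_dedup_lock(session_id, message, file_ids=None)
    if lock is None:
        return  # second POST of the same turn: answer with an empty SSE body

    try:
        ...  # enqueue the turn and relay SSE chunks to the client
    except GeneratorExit:
        # Client disconnected mid-turn: keep the lock so infra-level retries
        # stay blocked while the backend turn runs; the 30 s TTL cleans up.
        raise
    except Exception:
        await lock.release()  # turn errored: never block the next send
        raise
    else:
        await lock.release()  # turn completed normally
```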
@@ -0,0 +1,94 @@
"""Unit tests for backend.copilot.message_dedup."""

from unittest.mock import AsyncMock

import pytest
import pytest_mock

from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock


def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
    mock_redis = AsyncMock()
    mock_redis.set = AsyncMock(return_value=set_returns)
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=AsyncMock,
        return_value=mock_redis,
    )
    return mock_redis


@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Nothing to deduplicate — no Redis call made, None returned."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    result = await acquire_dedup_lock("sess-1", None, None)
    assert result is None
    mock_redis.set.assert_not_called()


@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """First request acquires the lock and returns a _DedupLock."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    mock_redis.set.assert_called_once()
    key_arg = mock_redis.set.call_args.args[0]
    assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")


@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Duplicate request (NX fails) returns None to signal the caller."""
    _patch_redis(mocker, set_returns=None)
    result = await acquire_dedup_lock("sess-1", "hello", None)
    assert result is None


@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """File IDs are sorted before hashing so order doesn't affect the key."""
    mock_redis_1 = _patch_redis(mocker, set_returns=True)
    await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
    key_ab = mock_redis_1.set.call_args.args[0]

    mock_redis_2 = _patch_redis(mocker, set_returns=True)
    await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
    key_ba = mock_redis_2.set.call_args.args[0]

    assert key_ab == key_ba


@pytest.mark.asyncio
async def test_release_deletes_key(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() calls Redis delete exactly once."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    await lock.release()
    mock_redis.delete.assert_called_once()


@pytest.mark.asyncio
async def test_release_swallows_redis_error(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() must not raise even when Redis delete fails."""
    mock_redis = _patch_redis(mocker, set_returns=True)
    mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    await lock.release()  # must not raise
    mock_redis.delete.assert_called_once()
@@ -302,6 +302,7 @@ async def record_token_usage(
    *,
    cache_read_tokens: int = 0,
    cache_creation_tokens: int = 0,
    model_cost_multiplier: float = 1.0,
) -> None:
    """Record token usage for a user across all windows.

@@ -315,12 +316,17 @@ async def record_token_usage(
    ``prompt_tokens`` should be the *uncached* input count (``input_tokens``
    from the API response). Cache counts are passed separately.

    ``model_cost_multiplier`` scales the final weighted total to reflect
    relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
    so that Opus turns deplete the rate limit faster, proportional to cost.

    Args:
        user_id: The user's ID.
        prompt_tokens: Uncached input tokens.
        completion_tokens: Output tokens.
        cache_read_tokens: Tokens served from prompt cache (10% cost).
        cache_creation_tokens: Tokens written to prompt cache (25% cost).
        model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
    """
    prompt_tokens = max(0, prompt_tokens)
    completion_tokens = max(0, completion_tokens)
@@ -332,7 +338,9 @@ async def record_token_usage(
        + round(cache_creation_tokens * 0.25)
        + round(cache_read_tokens * 0.1)
    )
-   total = weighted_input + completion_tokens
+   total = round(
+       (weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
+   )
    if total <= 0:
        return

@@ -340,11 +348,12 @@ async def record_token_usage(
        prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
    )
    logger.info(
-       "Recording token usage for %s: raw=%d, weighted=%d "
+       "Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
        "(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
        user_id[:8],
        raw_total,
        total,
        model_cost_multiplier,
        prompt_tokens,
        cache_read_tokens,
        cache_creation_tokens,
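Worked numbers for the weighting, using the documented rates (10% for cache reads, 25% for cache writes) and the 5× Opus multiplier. A back-of-the-envelope check, not the production function:

```python
prompt_tokens = 1_000          # uncached input, counted at full weight
cache_read_tokens = 8_000      # counted at 10%
cache_creation_tokens = 2_000  # counted at 25%
completion_tokens = 500
model_cost_multiplier = 5.0    # Opus

weighted_input = (
    prompt_tokens
    + round(cache_creation_tokens * 0.25)  # 500
    + round(cache_read_tokens * 0.1)       # 800
)
total = round((weighted_input + completion_tokens) * max(1.0, model_cost_multiplier))

assert weighted_input == 2_300
assert total == 14_000  # the same turn on Sonnet (1.0x) would record 2_800
```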
@@ -34,9 +34,13 @@ Steps:
   always inspect the current graph first so you know exactly what to change.
   Avoid using `include_graph=true` with broad keyword searches, as fetching
   multiple graphs at once is expensive and consumes LLM context budget.
-2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
+2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to
   search for relevant blocks. This returns block IDs, names, descriptions,
-  and full input/output schemas.
+  and full input/output schemas. The `for_agent_generation=true` flag is
+  required to surface graph-only blocks such as AgentInputBlock,
+  AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock,
+  WebhookBlock, and MCPToolBlock. (When running MCP tools interactively
+  in CoPilot outside agent generation, use `run_mcp_tool` instead.)
3. **Find library agents**: Call `find_library_agent` to discover reusable
   agents that can be composed as sub-agents via `AgentExecutorBlock`.
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:

@@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents:

### Using MCP Tools (MCPToolBlock)

> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP
> tools as persistent nodes in an agent graph. When running MCP tools directly in
> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles
> server discovery and authentication interactively. Use `MCPToolBlock` here only
> when the user wants the MCP call baked into a reusable agent graph.

To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
@@ -0,0 +1,555 @@
"""Tests for context fallback paths introduced in fix/copilot-transcript-resume-gate.

Scenario table
==============

| # | use_resume | transcript_msg_count | gap    | target_tokens | Expected output                             |
|---|------------|----------------------|--------|---------------|---------------------------------------------|
| A | True       | covers all           | empty  | None          | bare message (--resume has full context)    |
| B | True       | stale                | 2 msgs | None          | gap context prepended                       |
| C | True       | stale                | 2 msgs | 50_000        | gap compressed to budget, prepended         |
| D | False      | 0                    | N/A    | None          | full session compressed, prepended          |
| E | False      | 0                    | N/A    | 50_000        | full session compressed to budget           |
| F | False      | 2 (partial)          | 2 msgs | None          | full session compressed (not just gap;      |
|   |            |                      |        |               | CLI has zero context without --resume)      |
| G | False      | 2 (partial)          | 2 msgs | 50_000        | full session compressed to budget           |
| H | False      | covers all           | empty  | None          | full session compressed                     |
|   |            |                      |        |               | (NOT bare message — the bug that was fixed) |
| I | False      | covers all           | empty  | 50_000        | full session compressed to tight budget     |
| J | False      | 2 (partial)          | n/a    | None          | exactly ONE compression call (full prior)   |

Compression unit tests
======================

| # | Input                | target_tokens | Expected                                         |
|---|----------------------|---------------|--------------------------------------------------|
| K | []                   | None          | ([], False) — empty guard                        |
| L | [1 msg]              | None          | ([msg], False) — single-msg guard                |
| M | [2+ msgs]            | None          | target_tokens=None forwarded to _run_compression |
| N | [2+ msgs]            | 30_000        | target_tokens=30_000 forwarded                   |
| O | [2+ msgs], run fails | None          | returns originals, False                         |
"""

from __future__ import annotations

from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch

import pytest

from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import _build_query_message, _compress_messages
from backend.util.prompt import CompressResult

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_session(messages: list[ChatMessage]) -> ChatSession:
    now = datetime.now(UTC)
    return ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=messages,
        title="test",
        usage=[],
        started_at=now,
        updated_at=now,
    )


def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
    return [ChatMessage(role=r, content=c) for r, c in pairs]


def _passthrough_compress(target_tokens=None):
    """Return a mock that passes messages through and records its call args."""
    calls: list[tuple[list, int | None]] = []

    async def _mock(msgs, tok=None):
        calls.append((msgs, tok))
        return msgs, False

    _mock.calls = calls  # type: ignore[attr-defined]
    return _mock


# ---------------------------------------------------------------------------
# _build_query_message — scenario A–J
# ---------------------------------------------------------------------------


class TestBuildQueryMessageResume:
    """use_resume=True paths (--resume supplies history; only inject gap if stale)."""

    @pytest.mark.asyncio
    async def test_scenario_a_transcript_current_returns_bare_message(self):
        """Scenario A: --resume covers full context → no prefix injected."""
        session = _make_session(
            _msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
        )
        result, compacted = await _build_query_message(
            "q2", session, use_resume=True, transcript_msg_count=2, session_id="s"
        )
        assert result == "q2"
        assert compacted is False

    @pytest.mark.asyncio
    async def test_scenario_b_stale_transcript_injects_gap(self, monkeypatch):
        """Scenario B: stale transcript → gap context prepended."""
        session = _make_session(
            _msgs(
                ("user", "q1"),
                ("assistant", "a1"),
                ("user", "q2"),
                ("assistant", "a2"),
                ("user", "q3"),
            )
        )

        async def _mock_compress(msgs, target_tokens=None):
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        result, compacted = await _build_query_message(
            "q3", session, use_resume=True, transcript_msg_count=2, session_id="s"
        )
        assert "<conversation_history>" in result
        assert "q2" in result
        assert "a2" in result
        assert "Now, the user says:\nq3" in result
        # q1/a1 are covered by the transcript — must NOT appear in gap context
        assert "q1" not in result

    @pytest.mark.asyncio
    async def test_scenario_c_stale_transcript_passes_target_tokens(self, monkeypatch):
        """Scenario C: target_tokens is forwarded to _compress_messages for the gap."""
        session = _make_session(
            _msgs(
                ("user", "q1"),
                ("assistant", "a1"),
                ("user", "q2"),
                ("assistant", "a2"),
                ("user", "q3"),
            )
        )
        captured: list[int | None] = []

        async def _mock_compress(msgs, target_tokens=None):
            captured.append(target_tokens)
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        await _build_query_message(
            "q3",
            session,
            use_resume=True,
            transcript_msg_count=2,
            session_id="s",
            target_tokens=50_000,
        )
        assert captured == [50_000]


class TestBuildQueryMessageNoResumeNoTranscript:
    """use_resume=False, transcript_msg_count=0 — full session compressed."""

    @pytest.mark.asyncio
    async def test_scenario_d_full_session_compressed(self, monkeypatch):
        """Scenario D: no resume, no transcript → compress all prior messages."""
        session = _make_session(
            _msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
        )

        async def _mock_compress(msgs, target_tokens=None):
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        result, compacted = await _build_query_message(
            "q2", session, use_resume=False, transcript_msg_count=0, session_id="s"
        )
        assert "<conversation_history>" in result
        assert "q1" in result
        assert "a1" in result
        assert "Now, the user says:\nq2" in result

    @pytest.mark.asyncio
    async def test_scenario_e_passes_target_tokens_to_compression(self, monkeypatch):
        """Scenario E: target_tokens forwarded to _compress_messages."""
        session = _make_session(
            _msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
        )
        captured: list[int | None] = []

        async def _mock_compress(msgs, target_tokens=None):
            captured.append(target_tokens)
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        await _build_query_message(
            "q2",
            session,
            use_resume=False,
            transcript_msg_count=0,
            session_id="s",
            target_tokens=15_000,
        )
        assert captured == [15_000]


class TestBuildQueryMessageNoResumeWithTranscript:
    """use_resume=False, transcript_msg_count > 0 — gap or full-session fallback."""

    @pytest.mark.asyncio
    async def test_scenario_f_no_resume_always_injects_full_session(self, monkeypatch):
        """Scenario F: use_resume=False with transcript_msg_count > 0 still injects
        the FULL prior session — not just the gap since the transcript end.

        When there is no --resume the CLI starts with zero context, so injecting
        only the post-transcript gap would silently drop all transcript-covered
        history. The correct fix is to always compress the full session.
        """
        session = _make_session(
            _msgs(
                ("user", "q1"),  # transcript_msg_count=2 covers these
                ("assistant", "a1"),
                ("user", "q2"),  # post-transcript gap starts here
                ("assistant", "a2"),
                ("user", "q3"),  # current message
            )
        )
        compressed_msgs: list[list] = []

        async def _mock_compress(msgs, target_tokens=None):
            compressed_msgs.append(list(msgs))
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        result, _ = await _build_query_message(
            "q3",
            session,
            use_resume=False,
            transcript_msg_count=2,  # transcript covers q1/a1 but no --resume
            session_id="s",
        )
        assert "<conversation_history>" in result
        # Full session must be injected — transcript-covered turns ARE included
        assert "q1" in result
        assert "a1" in result
        assert "q2" in result
        assert "a2" in result
        assert "Now, the user says:\nq3" in result
        # Compressed exactly once with all 4 prior messages
        assert len(compressed_msgs) == 1
        assert len(compressed_msgs[0]) == 4

    @pytest.mark.asyncio
    async def test_scenario_g_no_resume_passes_target_tokens(self, monkeypatch):
        """Scenario G: target_tokens forwarded when use_resume=False + transcript_msg_count > 0."""
        session = _make_session(
            _msgs(
                ("user", "q1"),
                ("assistant", "a1"),
                ("user", "q2"),
                ("assistant", "a2"),
                ("user", "q3"),
            )
        )
        captured: list[int | None] = []

        async def _mock_compress(msgs, target_tokens=None):
            captured.append(target_tokens)
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        await _build_query_message(
            "q3",
            session,
            use_resume=False,
            transcript_msg_count=2,
            session_id="s",
            target_tokens=50_000,
        )
        assert captured == [50_000]

    @pytest.mark.asyncio
    async def test_scenario_h_no_resume_transcript_current_injects_full_session(
        self, monkeypatch
    ):
        """Scenario H: the bug that was fixed.

        Old code path: use_resume=False, transcript_msg_count covers all prior
        messages → gap sub-path: gap = [] → ``return current_message, False``
        → model received ZERO context (bare message only).

        New code path: use_resume=False always compresses the full prior session
        regardless of transcript_msg_count — model always gets context.
        """
        session = _make_session(
            _msgs(
                ("user", "q1"),
                ("assistant", "a1"),
                ("user", "q2"),
                ("assistant", "a2"),
                ("user", "q3"),
            )
        )

        async def _mock_compress(msgs, target_tokens=None):
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        result, _ = await _build_query_message(
            "q3",
            session,
            use_resume=False,
            transcript_msg_count=4,  # covers ALL prior → old code returned bare msg
            session_id="s",
        )
        # NEW: must inject full session, NOT return bare message
        assert result != "q3"
        assert "<conversation_history>" in result
        assert "q1" in result
        assert "Now, the user says:\nq3" in result

    @pytest.mark.asyncio
    async def test_scenario_i_no_resume_target_tokens_forwarded_any_transcript_count(
        self, monkeypatch
    ):
        """Scenario I: target_tokens forwarded even when transcript_msg_count covers all."""
        session = _make_session(
            _msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
        )
        captured: list[int | None] = []

        async def _mock_compress(msgs, target_tokens=None):
            captured.append(target_tokens)
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        await _build_query_message(
            "q2",
            session,
            use_resume=False,
            transcript_msg_count=2,
            session_id="s",
            target_tokens=15_000,
        )
        assert 15_000 in captured

    @pytest.mark.asyncio
    async def test_scenario_j_no_resume_single_compression_call(self, monkeypatch):
        """Scenario J: use_resume=False always makes exactly ONE compression call
        (the full session), regardless of transcript coverage.

        This verifies there is no two-step gap+fallback pattern for no-resume —
        compression is called once with the full prior session.
        """
        session = _make_session(
            _msgs(
                ("user", "q1"),
                ("assistant", "a1"),
                ("user", "q2"),
                ("assistant", "a2"),
                ("user", "q3"),
            )
        )
        call_count = 0

        async def _mock_compress(msgs, target_tokens=None):
            nonlocal call_count
            call_count += 1
            return msgs, False

        monkeypatch.setattr(
            "backend.copilot.sdk.service._compress_messages", _mock_compress
        )

        await _build_query_message(
            "q3",
            session,
            use_resume=False,
            transcript_msg_count=2,
            session_id="s",
        )
        assert call_count == 1


# ---------------------------------------------------------------------------
# _compress_messages — unit tests K–O
# ---------------------------------------------------------------------------


class TestCompressMessages:
    @pytest.mark.asyncio
    async def test_scenario_k_empty_list_returns_empty(self):
        """Scenario K: empty input → short-circuit, no compression."""
        result, compacted = await _compress_messages([])
        assert result == []
        assert compacted is False

    @pytest.mark.asyncio
    async def test_scenario_l_single_message_returns_as_is(self):
        """Scenario L: single message → short-circuit (< 2 guard)."""
        msg = ChatMessage(role="user", content="hello")
        result, compacted = await _compress_messages([msg])
        assert result == [msg]
        assert compacted is False

    @pytest.mark.asyncio
    async def test_scenario_m_target_tokens_none_forwarded(self):
        """Scenario M: target_tokens=None forwarded to _run_compression."""
        msgs = [
            ChatMessage(role="user", content="q"),
            ChatMessage(role="assistant", content="a"),
        ]
        fake_result = CompressResult(
            messages=[
                {"role": "user", "content": "q"},
                {"role": "assistant", "content": "a"},
            ],
            token_count=10,
            was_compacted=False,
            original_token_count=10,
        )
        with patch(
            "backend.copilot.sdk.service._run_compression",
            new_callable=AsyncMock,
            return_value=fake_result,
        ) as mock_run:
            await _compress_messages(msgs, target_tokens=None)

        mock_run.assert_awaited_once()
        _, kwargs = mock_run.call_args
        assert kwargs.get("target_tokens") is None

    @pytest.mark.asyncio
    async def test_scenario_n_explicit_target_tokens_forwarded(self):
        """Scenario N: explicit target_tokens forwarded to _run_compression."""
        msgs = [
            ChatMessage(role="user", content="q"),
            ChatMessage(role="assistant", content="a"),
        ]
        fake_result = CompressResult(
            messages=[{"role": "user", "content": "summary"}],
            token_count=5,
            was_compacted=True,
            original_token_count=50,
        )
        with patch(
            "backend.copilot.sdk.service._run_compression",
            new_callable=AsyncMock,
            return_value=fake_result,
        ) as mock_run:
            result, compacted = await _compress_messages(msgs, target_tokens=30_000)

        mock_run.assert_awaited_once()
        _, kwargs = mock_run.call_args
        assert kwargs.get("target_tokens") == 30_000
        assert compacted is True

    @pytest.mark.asyncio
    async def test_scenario_o_run_compression_exception_returns_originals(self):
        """Scenario O: _run_compression raises → return original messages, False."""
        msgs = [
            ChatMessage(role="user", content="q"),
            ChatMessage(role="assistant", content="a"),
        ]
        with patch(
            "backend.copilot.sdk.service._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("compression timeout"),
        ):
            result, compacted = await _compress_messages(msgs)

        assert result == msgs
        assert compacted is False

    @pytest.mark.asyncio
    async def test_compaction_messages_filtered_before_compression(self):
        """filter_compaction_messages is applied before _run_compression is called."""
        # A compaction message is one with role=assistant and specific content pattern.
        # We verify that only real messages reach _run_compression.
        from backend.copilot.sdk.service import filter_compaction_messages

        msgs = [
            ChatMessage(role="user", content="q"),
            ChatMessage(role="assistant", content="a"),
        ]
        # filter_compaction_messages should not remove these plain messages
        filtered = filter_compaction_messages(msgs)
        assert len(filtered) == len(msgs)


# ---------------------------------------------------------------------------
# target_tokens threading — _retry_target_tokens values match expectations
# ---------------------------------------------------------------------------


class TestRetryTargetTokens:
    def test_first_retry_uses_first_slot(self):
        from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS

        assert _RETRY_TARGET_TOKENS[0] == 50_000

    def test_second_retry_uses_second_slot(self):
        from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS

        assert _RETRY_TARGET_TOKENS[1] == 15_000

    def test_second_slot_smaller_than_first(self):
        from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS

        assert _RETRY_TARGET_TOKENS[1] < _RETRY_TARGET_TOKENS[0]


# ---------------------------------------------------------------------------
# Single-message session edge cases
# ---------------------------------------------------------------------------


class TestSingleMessageSessions:
    @pytest.mark.asyncio
    async def test_no_resume_single_message_returns_bare(self):
        """First turn (1 message): no prior history to inject."""
        session = _make_session([ChatMessage(role="user", content="hello")])
        result, compacted = await _build_query_message(
            "hello", session, use_resume=False, transcript_msg_count=0, session_id="s"
        )
        assert result == "hello"
        assert compacted is False

    @pytest.mark.asyncio
    async def test_resume_single_message_returns_bare(self):
        """First turn with resume flag: transcript is empty so no gap."""
        session = _make_session([ChatMessage(role="user", content="hello")])
        result, compacted = await _build_query_message(
            "hello", session, use_resume=True, transcript_msg_count=0, session_id="s"
        )
        assert result == "hello"
        assert compacted is False
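The branching these scenarios pin down can be summarized in a few lines. A sketch of the decision logic as implied by scenarios A–J; the real `_build_query_message` also handles message formatting, watermark misalignment, and the bare-message token floor, and `compress` here is a pass-through stand-in for `_compress_messages`:

```python
from __future__ import annotations


async def compress(msgs: list[str], target_tokens: int | None) -> tuple[list[str], bool]:
    """Stand-in for _compress_messages: pass-through, no compaction."""
    return msgs, False


async def build_query_sketch(
    current: str,
    prior: list[str],
    use_resume: bool,
    transcript_msg_count: int,
    target_tokens: int | None = None,
) -> tuple[str, bool]:
    if use_resume:
        gap = prior[transcript_msg_count:]  # turns --resume has not seen
        if not gap:
            return current, False  # scenario A: bare message
        ctx, compacted = await compress(gap, target_tokens)  # B, C
    else:
        if not prior:
            return current, False  # first turn of a session
        # D–J: without --resume the CLI has zero context, so ALWAYS
        # compress the full prior session, never just the gap.
        ctx, compacted = await compress(prior, target_tokens)
    history = "\n".join(ctx)
    return (
        f"<conversation_history>\n{history}\n</conversation_history>\n"
        f"Now, the user says:\n{current}",
        compacted,
    )
```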
@@ -0,0 +1,326 @@
|
||||
"""Tests for transcript context coverage when switching between fast and SDK modes.
|
||||
|
||||
When a user switches modes mid-session the transcript must bridge the gap so
|
||||
neither the baseline nor the SDK service loses context from turns produced by
|
||||
the other mode.
|
||||
|
||||
Cross-mode transcript flow
|
||||
==========================
|
||||
|
||||
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
|
||||
mode) read and write the same JSONL transcript store via
|
||||
``backend.copilot.transcript.upload_transcript`` /
|
||||
``download_transcript``.
|
||||
|
||||
Fast → SDK switch
|
||||
-----------------
|
||||
On the first SDK turn after N baseline turns:
|
||||
• ``use_resume=False`` — no CLI session exists from baseline mode.
|
||||
• ``transcript_msg_count > 0`` — the baseline transcript is downloaded and
|
||||
validated successfully.
|
||||
• ``_build_query_message`` must inject the FULL prior session (not just a
|
||||
"gap" since the transcript end) because the CLI has zero context without
|
||||
``--resume``.
|
||||
• After our fix, ``session_id`` IS set, so the CLI writes a session file
|
||||
on this turn → ``--resume`` works on T2+.
|
||||
|
||||
SDK → Fast switch
|
||||
-----------------
|
||||
On the first baseline turn after N SDK turns:
|
||||
• The baseline service downloads the SDK-written transcript.
|
||||
• ``_load_prior_transcript`` loads and validates it normally — the JSONL
|
||||
format is identical regardless of which mode wrote it.
|
||||
• ``transcript_covers_prefix=True`` → baseline sends ONLY new messages in
|
||||
its LLM payload (no double-counting of SDK history).
|
||||
|
||||
Scenario table (SDK _build_query_message)
|
||||
==========================================
|
||||
|
||||
| # | Scenario | use_resume | tmc | Expected query message |
|
||||
|---|--------------------------------|------------|-----|---------------------------------|
|
||||
| P | Fast→SDK T1 | False | 4 | full session injected |
|
||||
| Q | Fast→SDK T2+ (after fix) | True | 6 | bare message only (--resume ok) |
|
||||
| R | Fast→SDK T1, single baseline | False | 2 | full session injected |
|
||||
| S | SDK→Fast (baseline loads ok) | N/A | N/A | transcript covers prefix=True |
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import _build_query_message
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(messages: list[ChatMessage]) -> ChatSession:
|
||||
now = datetime.now(UTC)
|
||||
return ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=messages,
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
|
||||
return [ChatMessage(role=r, content=c) for r, c in pairs]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario P — Fast → SDK T1: full session injected from baseline transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFastToSdkModeSwitch:
|
||||
"""First SDK turn after N baseline (fast) turns.
|
||||
|
||||
The baseline transcript exists (has been uploaded by fast mode), but
|
||||
there is no CLI session file. ``_build_query_message`` must inject
|
||||
the complete prior session so the model has full context.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_p_full_session_injected_on_mode_switch_t1(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Scenario P: fast→SDK T1 injects all baseline turns into the query."""
|
||||
# Simulate 4 baseline messages (2 turns) followed by the first SDK turn.
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "baseline-q1"),
|
||||
("assistant", "baseline-a1"),
|
||||
("user", "baseline-q2"),
|
||||
("assistant", "baseline-a2"),
|
||||
("user", "sdk-q1"), # current SDK turn
|
||||
)
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
# transcript_msg_count=4: baseline uploaded a transcript covering all
|
||||
# 4 prior messages, but use_resume=False (no CLI session from baseline).
|
||||
result, compacted = await _build_query_message(
|
||||
"sdk-q1",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=4,
|
||||
session_id="s",
|
||||
)
|
||||
|
||||
# All baseline turns must appear — none of them can be silently dropped.
|
||||
assert "<conversation_history>" in result
|
||||
assert "baseline-q1" in result
|
||||
assert "baseline-a1" in result
|
||||
assert "baseline-q2" in result
|
||||
assert "baseline-a2" in result
|
||||
assert "Now, the user says:\nsdk-q1" in result
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_r_single_baseline_turn_injected(self, monkeypatch):
|
||||
"""Scenario R: even a single baseline turn is captured on mode-switch T1."""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "baseline-q1"),
|
||||
("assistant", "baseline-a1"),
|
||||
("user", "sdk-q1"),
|
||||
)
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
result, _ = await _build_query_message(
|
||||
"sdk-q1",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=2,
|
||||
session_id="s",
|
||||
)
|
||||
|
||||
assert "<conversation_history>" in result
|
||||
assert "baseline-q1" in result
|
||||
assert "baseline-a1" in result
|
||||
assert "Now, the user says:\nsdk-q1" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_q_sdk_t2_uses_resume_after_fix(self):
|
||||
"""Scenario Q: SDK T2+ uses --resume after mode-switch T1 set session_id.
|
||||
|
||||
With the mode-switch fix, T1 sets session_id → CLI writes session file →
|
||||
T2 restores the session → use_resume=True. _build_query_message must
|
||||
return the bare message (--resume supplies context via native session).
|
||||
"""
|
||||
# T2: 4 baseline turns + 1 SDK turn already recorded.
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "baseline-q1"),
|
||||
("assistant", "baseline-a1"),
|
||||
("user", "baseline-q2"),
|
||||
("assistant", "baseline-a2"),
|
||||
("user", "sdk-q1"),
|
||||
("assistant", "sdk-a1"),
|
||||
("user", "sdk-q2"), # current SDK T2 message
|
||||
)
|
||||
)
|
||||
|
||||
# transcript_msg_count=6 covers all prior messages → no gap.
|
||||
result, compacted = await _build_query_message(
|
||||
"sdk-q2",
|
||||
session,
|
||||
use_resume=True, # T2: --resume works after T1 set session_id
|
||||
transcript_msg_count=6,
|
||||
session_id="s",
|
||||
)
|
||||
|
||||
# --resume has full context — bare message only.
|
||||
assert result == "sdk-q2"
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_switch_t1_compresses_all_baseline_turns(self, monkeypatch):
|
||||
"""_compress_messages is called with ALL prior baseline messages.
|
||||
|
||||
There is exactly one compression call containing all 4 baseline messages
|
||||
— not just the 2 post-transcript-end messages.
|
||||
"""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "baseline-q1"),
|
||||
("assistant", "baseline-a1"),
|
||||
("user", "baseline-q2"),
|
||||
("assistant", "baseline-a2"),
|
||||
("user", "sdk-q1"),
|
||||
)
|
||||
)
|
||||
compressed_batches: list[list] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
compressed_batches.append(list(msgs))
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"sdk-q1",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=4,
|
||||
session_id="s",
|
||||
)
|
||||
|
||||
# Exactly one compression call, with all 4 prior messages.
|
||||
assert len(compressed_batches) == 1
|
||||
assert len(compressed_batches[0]) == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario S — SDK → Fast: baseline loads SDK-written transcript
|
# ---------------------------------------------------------------------------


class TestSdkToFastModeSwitch:
    """Fast mode turn after N SDK (extended_thinking) turns.

    The transcript written by SDK mode uses the same JSONL format as the one
    written by baseline mode (both go through ``TranscriptBuilder``).
    ``_load_prior_transcript`` must accept it and mark the prefix as covered.
    """

    @pytest.mark.asyncio
    async def test_scenario_s_baseline_loads_sdk_transcript(self):
        """Scenario S: SDK-written transcript is accepted by baseline's load helper."""
        from backend.copilot.baseline.service import _load_prior_transcript
        from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        # Build a minimal valid transcript as SDK mode would write it.
        # SDK uses append_user / append_assistant on TranscriptBuilder.
        builder_sdk = TranscriptBuilder()
        builder_sdk.append_user(content="sdk-question")
        builder_sdk.append_assistant(
            content_blocks=[{"type": "text", "text": "sdk-answer"}],
            model="claude-sonnet-4",
            stop_reason=STOP_REASON_END_TURN,
        )
        sdk_transcript = builder_sdk.to_jsonl()

        # Baseline session now has those 2 SDK messages + 1 new baseline message.
        download = TranscriptDownload(content=sdk_transcript, message_count=2)

        baseline_builder = TranscriptBuilder()
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=3,  # 2 SDK + 1 new baseline
                transcript_builder=baseline_builder,
            )

        # Transcript is valid and covers the prefix.
        assert covers is True
        assert baseline_builder.entry_count == 2

    @pytest.mark.asyncio
    async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
        """Scenario S (stale): SDK transcript is stale — baseline does not load it.

        If SDK mode produced more turns than the transcript captured (e.g.
        upload failed on one turn), the baseline rejects the stale transcript
        to avoid injecting an incomplete history.
        """
        from backend.copilot.baseline.service import _load_prior_transcript
        from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
        from backend.copilot.transcript_builder import TranscriptBuilder

        builder_sdk = TranscriptBuilder()
        builder_sdk.append_user(content="sdk-question")
        builder_sdk.append_assistant(
            content_blocks=[{"type": "text", "text": "sdk-answer"}],
            model="claude-sonnet-4",
            stop_reason=STOP_REASON_END_TURN,
        )
        sdk_transcript = builder_sdk.to_jsonl()

        # Transcript covers only 2 messages but session has 10 (many SDK turns).
        download = TranscriptDownload(content=sdk_transcript, message_count=2)

        baseline_builder = TranscriptBuilder()
        with patch(
            "backend.copilot.baseline.service.download_transcript",
            new=AsyncMock(return_value=download),
        ):
            covers = await _load_prior_transcript(
                user_id="user-1",
                session_id="session-1",
                session_msg_count=10,
                transcript_builder=baseline_builder,
            )

        # Stale transcript must be rejected.
        assert covers is False
        assert baseline_builder.is_empty
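The two scenarios pin down a simple coverage rule. Restated as a standalone predicate, inferred from the assertions above (the helper's real body may differ):

```python
def transcript_covers_prefix(transcript_msg_count: int, session_msg_count: int) -> bool:
    # The stored transcript must cover every message that predates the one
    # new turn being handled now; anything less is treated as stale.
    return transcript_msg_count >= session_msg_count - 1

assert transcript_covers_prefix(2, 3) is True    # 2 prior SDK messages + 1 new turn
assert transcript_covers_prefix(2, 10) is False  # several turns went unrecorded
```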
@@ -207,7 +207,7 @@ class TestConfigDefaults:

    def test_max_budget_usd_default(self):
        cfg = _make_config()
-        assert cfg.claude_agent_max_budget_usd == 15.0
+        assert cfg.claude_agent_max_budget_usd == 10.0

    def test_max_thinking_tokens_default(self):
        cfg = _make_config()
@@ -6,6 +6,7 @@ import pytest

from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
+    _BARE_MESSAGE_TOKEN_FLOOR,
    _build_query_message,
    _format_conversation_context,
)
@@ -130,6 +131,34 @@ async def test_build_query_resume_up_to_date():
    assert was_compacted is False


@pytest.mark.asyncio
async def test_build_query_resume_misaligned_watermark():
    """With --resume and watermark pointing at a user message, skip gap."""
    # Simulates a deleted message shifting DB positions so the watermark
    # lands on a user turn instead of the expected assistant turn.
    session = _make_session(
        [
            ChatMessage(role="user", content="turn 1"),
            ChatMessage(role="assistant", content="reply 1"),
            ChatMessage(
                role="user", content="turn 2"
            ),  # ← watermark points here (role=user)
            ChatMessage(role="assistant", content="reply 2"),
            ChatMessage(role="user", content="turn 3"),
        ]
    )
    result, was_compacted = await _build_query_message(
        "turn 3",
        session,
        use_resume=True,
        transcript_msg_count=3,  # prior[2].role == "user" — misaligned
        session_id="test-session",
    )
    # Misaligned watermark → skip gap, return bare message
    assert result == "turn 3"
    assert was_compacted is False


@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
    """With --resume and stale transcript, gap context is prepended."""
@@ -204,7 +233,7 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
    )

    # Mock _compress_messages to return the messages as-is
-    async def _mock_compress(msgs):
+    async def _mock_compress(msgs, target_tokens=None):
        return msgs, False

    monkeypatch.setattr(
@@ -237,7 +266,7 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
        ]
    )

-    async def _mock_compress(msgs):
+    async def _mock_compress(msgs, target_tokens=None):
        return msgs, True  # Simulate actual compaction

    monkeypatch.setattr(
@@ -253,3 +282,85 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
        session_id="test-session",
    )
    assert was_compacted is True


@pytest.mark.asyncio
async def test_build_query_no_resume_at_token_floor():
    """When target_tokens is at or below the floor, return bare message.

    This is the final escape hatch: if the retry budget is exhausted and
    even the most aggressive compression might not fit, skip history
    injection entirely so the user always gets a response.
    """
    session = _make_session(
        [
            ChatMessage(role="user", content="old question"),
            ChatMessage(role="assistant", content="old answer"),
            ChatMessage(role="user", content="new question"),
        ]
    )
    result, was_compacted = await _build_query_message(
        "new question",
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
        target_tokens=_BARE_MESSAGE_TOKEN_FLOOR,
    )
    # At the floor threshold, no history is injected
    assert result == "new question"
    assert was_compacted is False


@pytest.mark.asyncio
async def test_build_query_no_resume_below_token_floor():
    """target_tokens strictly below floor also returns bare message."""
    session = _make_session(
        [
            ChatMessage(role="user", content="old"),
            ChatMessage(role="assistant", content="reply"),
            ChatMessage(role="user", content="new"),
        ]
    )
    result, was_compacted = await _build_query_message(
        "new",
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
        target_tokens=_BARE_MESSAGE_TOKEN_FLOOR - 1,
    )
    assert result == "new"
    assert was_compacted is False


@pytest.mark.asyncio
async def test_build_query_no_resume_above_token_floor_compresses(monkeypatch):
    """target_tokens just above the floor still triggers compression."""
    session = _make_session(
        [
            ChatMessage(role="user", content="old"),
            ChatMessage(role="assistant", content="reply"),
            ChatMessage(role="user", content="new"),
        ]
    )

    async def _mock_compress(msgs, target_tokens=None):
        return msgs, False

    monkeypatch.setattr(
        "backend.copilot.sdk.service._compress_messages",
        _mock_compress,
    )

    result, was_compacted = await _build_query_message(
        "new",
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
        target_tokens=_BARE_MESSAGE_TOKEN_FLOOR + 1,
    )
    # Above the floor → history is injected (not the bare message)
    assert "<conversation_history>" in result
    assert "Now, the user says:\nnew" in result
@@ -7,6 +7,7 @@ tests will catch it immediately.
"""

import inspect
+from typing import cast

import pytest

@@ -90,6 +91,39 @@ def test_agent_options_accepts_required_fields():
    assert opts.cwd == "/tmp"


def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
    """Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.

    The production code always includes ``exclude_dynamic_sections=True`` in the preset
    dict. This compat test mirrors that exact shape so any SDK version that starts
    rejecting unknown keys will be caught here rather than at runtime.
    """
    from claude_agent_sdk import ClaudeAgentOptions
    from claude_agent_sdk.types import SystemPromptPreset

    from .service import _build_system_prompt_value

    # Call the production helper directly so this test is tied to the real
    # dict shape rather than a hand-rolled copy.
    preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
    assert isinstance(
        preset, dict
    ), "_build_system_prompt_value must return a dict when caching is on"

    sdk_preset = cast(SystemPromptPreset, preset)
    opts = ClaudeAgentOptions(system_prompt=sdk_preset)
    assert opts.system_prompt == sdk_preset


def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
    """When cross_user_cache=False (e.g. on --resume turns), the helper must return
    a plain string so the preset+resume crash is avoided."""
    from .service import _build_system_prompt_value

    result = _build_system_prompt_value("my prompt", cross_user_cache=False)
    assert result == "my prompt", "Must return the raw string, not a preset dict"


def test_agent_options_accepts_all_our_fields():
    """Comprehensive check of every field we use in service.py."""
    from claude_agent_sdk import ClaudeAgentOptions
@@ -1,5 +1,6 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""

+# isort: skip_file
import asyncio
import base64
import json
@@ -17,7 +18,7 @@ from dataclasses import field as dataclass_field
from typing import TYPE_CHECKING, Any, NamedTuple, cast

if TYPE_CHECKING:
-    from backend.copilot.permissions import CopilotPermissions
+    from ..permissions import CopilotPermissions

from claude_agent_sdk import (
    AssistantMessage,
@@ -29,16 +30,17 @@ from claude_agent_sdk import (
    ToolResultBlock,
    ToolUseBlock,
)
+from claude_agent_sdk.types import SystemPromptPreset
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
from pydantic import BaseModel

-from backend.copilot.context import get_workspace_manager
-from backend.copilot.permissions import apply_tool_permissions
-from backend.copilot.rate_limit import get_user_tier
-from backend.copilot.thinking_stripper import ThinkingStripper
-from backend.copilot.transcript import (
+from ..context import get_workspace_manager
+from ..permissions import apply_tool_permissions
+from ..rate_limit import get_user_tier
+from ..thinking_stripper import ThinkingStripper
+from ..transcript import (
    _run_compression,
    cleanup_stale_project_dirs,
    compact_transcript,
@@ -49,13 +51,13 @@ from backend.copilot.transcript import (
    upload_transcript,
    validate_transcript,
)
-from backend.copilot.transcript_builder import TranscriptBuilder
+from ..transcript_builder import TranscriptBuilder
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings

-from ..config import ChatConfig, CopilotMode
+from ..config import ChatConfig, CopilotLlmModel, CopilotMode
from ..constants import (
    COPILOT_ERROR_PREFIX,
    COPILOT_RETRYABLE_ERROR_PREFIX,
@@ -131,6 +133,11 @@ _MAX_STREAM_ATTEMPTS = 3
# self-correct. The limit is generous to allow recovery attempts.
_EMPTY_TOOL_CALL_LIMIT = 5

+# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet
+# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus
+# turns deplete quota proportionally faster.
+_OPUS_COST_MULTIPLIER = 5.0
+
# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
    "AutoPilot was unable to complete the tool call "
@@ -260,6 +267,11 @@ class ReducedContext(NamedTuple):
    resume_file: str | None
    transcript_lost: bool
    tried_compaction: bool
+    # Token budget for history compression on the DB-message fallback path.
+    # None means "use model-aware default". Halved on each retry so
+    # compress_context applies progressively more aggressive reduction
+    # (LLM summarize → content truncate → middle-out delete → first/last trim).
+    target_tokens: int | None = None


@dataclass
@@ -304,6 +316,10 @@ class _RetryState:
    adapter: SDKResponseAdapter
    transcript_builder: TranscriptBuilder
    usage: _TokenUsage
+    # Token budget for history compression on retries (DB-message fallback path).
+    # None = model-aware default. Halved each retry for progressively more
+    # aggressive compression (LLM summarize → truncate → middle-out → trim).
+    target_tokens: int | None = None


@dataclass
@@ -335,12 +351,34 @@ class _StreamContext:
    lock: AsyncClusterLock


+# Per-retry token budgets for the no-transcript (use_resume=False) path.
+# When there is no CLI native session to --resume, context is built from DB
+# messages via _format_conversation_context. For large sessions this text
+# can exceed the model context window; each retry halves the token budget so
+# compress_context applies progressively more aggressive reduction:
+# LLM summarize → content truncate → middle-out delete → first/last trim.
+# Index 0 = first retry, 1 = second retry; last value applies beyond that.
+_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)
+
+# Below this token budget the model context is so tight that injecting any
+# conversation history would likely exceed the limit regardless of content.
+# _build_query_message returns the bare message when target_tokens falls to
+# or below this floor, giving the user a response instead of a hard error.
+_BARE_MESSAGE_TOKEN_FLOOR: int = 5_000
+
+# Tight token budget for seeding the transcript builder on turns where no
+# CLI native session exists. Kept below _RETRY_TARGET_TOKENS[0] so the
+# seeded JSONL upload stays compact and future gap injections are small.
+_SEED_TARGET_TOKENS: int = 30_000

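The tuple above drives a clamped lookup in `_reduce_context` (next hunk). A minimal, self-contained restatement of that lookup; `retry_target` is an illustrative name, not a function in this codebase:

```python
_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)

def retry_target(attempt: int) -> int:
    # attempt 1 -> index 0, attempt 2 -> index 1; anything beyond reuses
    # the last (tightest) budget, mirroring the clamp in _reduce_context.
    idx = max(0, attempt - 1)
    return _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]

assert retry_target(1) == 50_000
assert retry_target(2) == 15_000
assert retry_target(99) == 15_000  # clamped; pinned by a test further down
```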
async def _reduce_context(
    transcript_content: str,
    tried_compaction: bool,
    session_id: str,
    sdk_cwd: str,
    log_prefix: str,
+    attempt: int = 1,
) -> ReducedContext:
    """Prepare reduced context for a retry attempt.
@@ -348,9 +386,19 @@ async def _reduce_context(

    On subsequent retries (or if compaction fails), drops the transcript
    entirely so the query is rebuilt from DB messages only.

-    `transcript_lost` is True when the transcript was dropped (caller
-    should set `skip_transcript_upload`).
+    When no transcript is available (use_resume=False fallback path), returns
+    a decreasing ``target_tokens`` budget so ``compress_context`` applies
+    progressively more aggressive reduction (LLM summarize → content truncate
+    → middle-out delete → first/last trim). The budget applies in
+    ``_build_query_message`` and is halved on each retry.
+
+    ``transcript_lost`` is True when the transcript was dropped (caller
+    should set ``skip_transcript_upload``).
    """
+    # Token budget for the DB fallback on this attempt (no-transcript path).
+    idx = max(0, attempt - 1)
+    retry_target = _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]
+
    # First retry: try compacting our transcript builder state.
    # Note: the CLI native --resume file is not updated with the compacted
    # content (it would require emitting CLI-native JSONL format), so the
@@ -374,9 +422,14 @@ async def _reduce_context(
            return ReducedContext(tb, False, None, False, True)
        logger.warning("%s Compaction failed, dropping transcript", log_prefix)

-    # Subsequent retry or compaction failed: drop transcript entirely
-    logger.warning("%s Dropping transcript, rebuilding from DB messages", log_prefix)
-    return ReducedContext(TranscriptBuilder(), False, None, True, True)
+    # Subsequent retry or compaction failed: drop transcript entirely.
+    # Return retry_target so the caller compresses DB messages to that budget.
+    logger.warning(
+        "%s Dropping transcript, rebuilding from DB messages" " (target_tokens=%d)",
+        log_prefix,
+        retry_target,
+    )
+    return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)

def _append_error_marker(
@@ -627,6 +680,48 @@ def _resolve_fallback_model() -> str | None:
    return _normalize_model_name(raw)


async def _resolve_model_and_multiplier(
    model: "CopilotLlmModel | None",
    session_id: str,
) -> tuple[str | None, float]:
    """Resolve the SDK model string and rate-limit cost multiplier for a turn.

    Priority (highest first):
    1. Explicit per-request ``model`` tier from the frontend toggle.
    2. Global config default (``_resolve_sdk_model()``).

    Returns a ``(sdk_model, cost_multiplier)`` pair.
    ``sdk_model`` is ``None`` when the Claude Code subscription default applies.
    ``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise.
    """
    sdk_model = _resolve_sdk_model()

    if model == "advanced":
        sdk_model = _normalize_model_name("anthropic/claude-opus-4-6")
        logger.info(
            "[SDK] [%s] Per-request model override: advanced (%s)",
            session_id[:12] if session_id else "?",
            sdk_model,
        )
        return sdk_model, _OPUS_COST_MULTIPLIER

    if model == "standard":
        # Reset to config default — respects subscription mode (None = CLI default).
        sdk_model = _resolve_sdk_model()
        logger.info(
            "[SDK] [%s] Per-request model override: standard (%s)",
            session_id[:12] if session_id else "?",
            sdk_model or "subscription-default",
        )
        return sdk_model, 1.0

    # No per-request override; derive multiplier from final resolved model.
    cost_multiplier = (
        _OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0
    )
    return sdk_model, cost_multiplier


_MAX_TRANSIENT_BACKOFF_SECONDS = 30

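A quick sketch of the three branches the docstring describes. Exact return values depend on config, so treat this as illustrative rather than fixed behavior:

```python
async def _demo_model_resolution() -> None:
    # "advanced": explicit Opus override, 5x rate-limit cost.
    m, mult = await _resolve_model_and_multiplier("advanced", "sess-123")
    assert mult == 5.0  # _OPUS_COST_MULTIPLIER

    # "standard": config default (may be None under subscription mode), 1x.
    m, mult = await _resolve_model_and_multiplier("standard", "sess-123")
    assert mult == 1.0

    # No override: multiplier inferred from the resolved model name.
    m, mult = await _resolve_model_and_multiplier(None, "sess-123")
    assert mult == (5.0 if m and "opus" in m else 1.0)
```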
@@ -705,6 +800,34 @@ def _is_fallback_stderr(line: str) -> bool:
    return "fallback model" in line.lower()


def _build_system_prompt_value(
    system_prompt: str,
    cross_user_cache: bool,
) -> str | SystemPromptPreset:
    """Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`.

    When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset`
    dict so the Claude Code default prompt becomes a cacheable prefix shared
    across all users; our custom *system_prompt* is appended after it.

    When disabled (or if the SDK is too old to support ``SystemPromptPreset``),
    the raw *system_prompt* string is returned unchanged.

    An empty *system_prompt* is accepted: the preset dict will have
    ``append: ""`` which the SDK treats as no custom suffix.
    """
    if cross_user_cache:
        logger.debug("Using SystemPromptPreset for cross-user prompt cache")
        return SystemPromptPreset(
            type="preset",
            preset="claude_code",
            append=system_prompt,
            exclude_dynamic_sections=True,
        )
    logger.debug("Cross-user prompt cache disabled, using raw string")
    return system_prompt


def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

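The two shapes the helper can return, spelled out with values taken directly from the body above:

```python
_build_system_prompt_value("be terse", cross_user_cache=True)
# -> {"type": "preset", "preset": "claude_code",
#     "append": "be terse", "exclude_dynamic_sections": True}

_build_system_prompt_value("be terse", cross_user_cache=False)
# -> "be terse"
```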
@@ -801,6 +924,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:


async def _compress_messages(
    messages: list[ChatMessage],
+    target_tokens: int | None = None,
) -> tuple[list[ChatMessage], bool]:
    """Compress a list of messages if they exceed the token threshold.
@@ -809,6 +933,10 @@ async def _compress_messages(

    `_compress_messages` and `compact_transcript` share this helper so
    client acquisition and error handling are consistent.

+    ``target_tokens`` sets a hard ceiling for the compressed output so
+    callers can enforce a tighter budget on retries. When ``None``,
+    ``compress_context`` uses the model-aware default.
+
    See also:
        `_run_compression` — shared compression with timeout guards.
        `compact_transcript` — compresses JSONL transcript entries.
@@ -832,7 +960,9 @@ async def _compress_messages(
        messages_dict.append(msg_dict)

    try:
-        result = await _run_compression(messages_dict, config.model, "[SDK]")
+        result = await _run_compression(
+            messages_dict, config.model, "[SDK]", target_tokens=target_tokens
+        )
    except Exception as exc:
        # Guard against timeouts or unexpected errors in compression —
        # return the original messages so the caller can proceed without
@@ -961,44 +1091,139 @@ async def _build_query_message(
    use_resume: bool,
    transcript_msg_count: int,
    session_id: str,
+    target_tokens: int | None = None,
) -> tuple[str, bool]:
    """Build the query message with appropriate context.

    When ``use_resume=True``, the CLI has the full session via ``--resume``;
    only a gap-fill prefix is injected when the transcript is stale.

+    When ``use_resume=False``, the CLI starts a fresh session with no prior
+    context, so the full prior session is always compressed and injected via
+    ``_format_conversation_context``. ``compress_context`` handles size
+    reduction internally (LLM summarize → content truncate → middle-out delete
+    → first/last trim). ``target_tokens`` decreases on each retry to force
+    progressively more aggressive compression when the first attempt exceeds
+    context limits.
+
    Returns:
        Tuple of (query_message, was_compacted).
    """
    msg_count = len(session.messages)
+    prior = session.messages[:-1]  # all turns except the current user message
+
+    logger.info(
+        "[SDK] [%s] Context path: use_resume=%s, transcript_msg_count=%d,"
+        " db_msg_count=%d, target_tokens=%s",
+        session_id[:8],
+        use_resume,
+        transcript_msg_count,
+        msg_count,
+        target_tokens,
+    )

    if use_resume and transcript_msg_count > 0:
        if transcript_msg_count < msg_count - 1:
-            gap = session.messages[transcript_msg_count:-1]
-            compressed, was_compressed = await _compress_messages(gap)
+            # Sanity-check the watermark: the last covered position should be
+            # an assistant turn. A user-role message here means the count is
+            # misaligned (e.g. a message was deleted and DB positions shifted).
+            # Skip the gap rather than injecting wrong context — the CLI session
+            # loaded via --resume still has good history.
+            if prior[transcript_msg_count - 1].role != "assistant":
+                logger.warning(
+                    "[SDK] [%s] Watermark misaligned: prior[%d].role=%r"
+                    " (expected 'assistant') — skipping gap to avoid"
+                    " injecting wrong context (transcript=%d, db=%d)",
+                    session_id[:8],
+                    transcript_msg_count - 1,
+                    prior[transcript_msg_count - 1].role,
+                    transcript_msg_count,
+                    msg_count,
+                )
+                return current_message, False
+            gap = prior[transcript_msg_count:]
+            compressed, was_compressed = await _compress_messages(gap, target_tokens)
            gap_context = _format_conversation_context(compressed)
            if gap_context:
                logger.info(
                    "[SDK] Transcript stale: covers %d of %d messages, "
-                    "gap=%d (compressed=%s)",
+                    "gap=%d (compressed=%s), gap_context_bytes=%d",
                    transcript_msg_count,
                    msg_count,
                    len(gap),
                    was_compressed,
+                    len(gap_context),
                )
                return (
                    f"{gap_context}\n\nNow, the user says:\n{current_message}",
                    was_compressed,
                )
+            logger.warning(
+                "[SDK] [%s] Transcript stale: gap produced empty context"
+                " (%d msgs, transcript=%d/%d) — sending message without gap prefix",
+                session_id[:8],
+                len(gap),
+                transcript_msg_count,
+                msg_count,
+            )
+        else:
+            logger.info(
+                "[SDK] [%s] --resume covers full context (%d messages)",
+                session_id[:8],
+                transcript_msg_count,
+            )
        return current_message, False

    elif not use_resume and msg_count > 1:
+        # No --resume: the CLI starts a fresh session with no prior context.
+        # Injecting only the post-transcript gap would omit the transcript-covered
+        # prefix entirely, so always compress the full prior session here.
+        # compress_context handles size reduction internally (LLM summarize →
+        # content truncate → middle-out delete → first/last trim).
+
+        # Final escape hatch: if the token budget is at or below the floor,
+        # the model context is so tight that even fully compressed history
+        # would risk a "prompt too long" error. Return the bare message so
+        # the user always gets a response rather than a hard failure.
+        if target_tokens is not None and target_tokens <= _BARE_MESSAGE_TOKEN_FLOOR:
+            logger.warning(
+                "[SDK] [%s] target_tokens=%d at or below floor (%d) —"
+                " skipping history injection to guarantee response delivery"
+                " (session has %d messages)",
+                session_id[:8],
+                target_tokens,
+                _BARE_MESSAGE_TOKEN_FLOOR,
+                msg_count,
+            )
+            return current_message, False
+
        logger.warning(
-            f"[SDK] Using compression fallback for session "
-            f"{session_id} ({msg_count} messages) — no transcript for --resume"
+            "[SDK] [%s] No --resume for %d-message session — compressing"
+            " full session history (pod affinity issue or first turn after"
+            " restore failure); target_tokens=%s",
+            session_id[:8],
+            msg_count,
+            target_tokens,
        )
-        compressed, was_compressed = await _compress_messages(session.messages[:-1])
+        compressed, was_compressed = await _compress_messages(prior, target_tokens)
        history_context = _format_conversation_context(compressed)
        if history_context:
+            logger.info(
+                "[SDK] [%s] Fallback context built: compressed=%s," " context_bytes=%d",
+                session_id[:8],
+                was_compressed,
+                len(history_context),
+            )
            return (
                f"{history_context}\n\nNow, the user says:\n{current_message}",
                was_compressed,
            )
+        logger.warning(
+            "[SDK] [%s] Fallback context empty after compression"
+            " (%d messages) — sending message without history",
+            session_id[:8],
+            len(prior),
+        )

    return current_message, False
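A worked pass through the gap arithmetic above, with concrete values:

```python
# Session: [user, assistant, user, assistant, user]  -> msg_count = 5
# prior = session.messages[:-1]                      -> 4 prior messages
# transcript_msg_count = 2                           -> transcript ends after "reply 1"
# prior[1].role == "assistant"                       -> watermark aligned, proceed
# gap = prior[2:]                                    -> [user "turn 2", assistant "reply 2"]
# Only the gap is compressed and prepended; --resume supplies everything earlier.
```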
@@ -1688,15 +1913,20 @@ async def _run_stream_attempt(
                # cache_read_input_tokens = served from cache
                # cache_creation_input_tokens = written to cache
                if sdk_msg.usage:
-                    state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
-                    state.usage.cache_read_tokens += sdk_msg.usage.get(
-                        "cache_read_input_tokens", 0
+                    # Use `or 0` instead of a default in .get() because
+                    # OpenRouter may include the key with a null value (e.g.
+                    # {"cache_read_input_tokens": null}) for models that don't
+                    # yet report cache tokens, making .get("key", 0) return
+                    # None rather than the fallback 0.
+                    state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
+                    state.usage.cache_read_tokens += (
+                        sdk_msg.usage.get("cache_read_input_tokens") or 0
                    )
-                    state.usage.cache_creation_tokens += sdk_msg.usage.get(
-                        "cache_creation_input_tokens", 0
+                    state.usage.cache_creation_tokens += (
+                        sdk_msg.usage.get("cache_creation_input_tokens") or 0
                    )
-                    state.usage.completion_tokens += sdk_msg.usage.get(
-                        "output_tokens", 0
+                    state.usage.completion_tokens += (
+                        sdk_msg.usage.get("output_tokens") or 0
                    )
                    logger.info(
                        "%s Token usage: uncached=%d, cache_read=%d, "
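The one-line distinction behind this change: `dict.get(key, default)` only falls back when the key is absent; a key present with `None` still returns `None`.

```python
usage = {"input_tokens": 0, "cache_read_input_tokens": None}

usage.get("cache_read_input_tokens", 0)    # None: the key exists, so the default is ignored
usage.get("cache_read_input_tokens") or 0  # 0: None is coerced to the fallback
```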
@@ -1922,6 +2152,48 @@ async def _run_stream_attempt(
    )


async def _seed_transcript(
    session: ChatSession,
    transcript_builder: TranscriptBuilder,
    transcript_covers_prefix: bool,
    transcript_msg_count: int,
    log_prefix: str,
) -> tuple[str, bool, int]:
    """Seed the transcript builder from compressed DB messages.

    Called when ``use_resume=False`` and no prior transcript exists in storage
    so that ``upload_transcript`` saves a compact version for future turns.
    This ensures the next turn can use the full-session compression path with
    the benefit of an already-compressed baseline, and a restored CLI session
    on the next pod gets a usable compact base even for sessions that started
    on old pods.

    Returns updated ``(transcript_content, transcript_covers_prefix,
    transcript_msg_count)`` values — unchanged if seeding is not possible.
    """
    if len(session.messages) <= 1:
        return "", transcript_covers_prefix, transcript_msg_count

    _prior = session.messages[:-1]
    _comp, _ = await _compress_messages(_prior, _SEED_TARGET_TOKENS)
    if not _comp:
        return "", transcript_covers_prefix, transcript_msg_count

    _seeded = _session_messages_to_transcript(_comp)
    if not _seeded or not validate_transcript(_seeded):
        return "", transcript_covers_prefix, transcript_msg_count

    transcript_builder.load_previous(_seeded, log_prefix=log_prefix)
    logger.info(
        "%s Seeded transcript from %d compressed DB messages"
        " for next-turn upload (seed_target_tokens=%d)",
        log_prefix,
        len(_comp),
        _SEED_TARGET_TOKENS,
    )
    return _seeded, True, len(_prior)

async def stream_chat_completion_sdk(
    session_id: str,
    message: str | None = None,
@@ -1931,6 +2203,7 @@ async def stream_chat_completion_sdk(
    file_ids: list[str] | None = None,
    permissions: "CopilotPermissions | None" = None,
    mode: CopilotMode | None = None,
+    model: CopilotLlmModel | None = None,
    **_kwargs: Any,
) -> AsyncIterator[StreamBaseResponse]:
    """Stream chat completion using Claude Agent SDK.
@@ -1941,6 +2214,9 @@ async def stream_chat_completion_sdk(
            saved to the SDK working directory for the Read tool.
        mode: Accepted for signature compatibility with the baseline path.
            The SDK path does not currently branch on this value.
+        model: Per-request model preference from the frontend toggle.
+            'advanced' → Claude Opus; 'standard' → global config default.
+            Takes priority over per-user LaunchDarkly targeting.
    """
    _ = mode  # SDK path ignores the requested mode.
@@ -2055,6 +2331,10 @@ async def stream_chat_completion_sdk(
    turn_cache_creation_tokens = 0
    turn_cost_usd: float | None = None
    graphiti_enabled = False
+    # Defaults ensure the finally block can always reference these safely even when
+    # an early return (e.g. sdk_cwd error) skips their normal assignment below.
+    sdk_model: str | None = None
+    model_cost_multiplier: float = 1.0

    # Make sure there is no more code between the lock acquisition and try-block.
    try:
@@ -2145,7 +2425,7 @@ async def stream_chat_completion_sdk(

        # Warm context: pre-load relevant facts from Graphiti on first turn
        if graphiti_enabled and user_id and len(session.messages) <= 1:
-            from backend.copilot.graphiti.context import fetch_warm_context
+            from ..graphiti.context import fetch_warm_context

            warm_ctx = await fetch_warm_context(user_id, message or "")
            if warm_ctx:
@@ -2193,9 +2473,20 @@ async def stream_chat_completion_sdk(
                    # Builder loaded but CLI native session not available.
                    # --resume will not be used this turn; upload after turn
                    # will seed the native session for the next turn.
+                    #
+                    # Still record transcript_msg_count so _build_query_message
+                    # can use the transcript-aware gap path (inject only new
+                    # messages since the transcript end) instead of compressing
+                    # the full DB history. This avoids prompt-too-long on
+                    # large sessions where the CLI session is temporarily
+                    # unavailable (e.g. mixed-version rolling deployment).
+                    transcript_msg_count = dl.message_count
                    logger.info(
-                        "%s CLI session not restored — running without --resume this turn",
+                        "%s CLI session not restored — running without"
+                        " --resume this turn (transcript_msg_count=%d for"
+                        " gap-aware fallback)",
                        log_prefix,
+                        transcript_msg_count,
                    )
                else:
                    logger.warning("%s Transcript downloaded but invalid", log_prefix)
@@ -2255,7 +2546,10 @@ async def stream_chat_completion_sdk(

        mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)

-        sdk_model = _resolve_sdk_model()
+        # Resolve model and cost multiplier (request tier → config default).
+        sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier(
+            model, session_id
+        )

        # Track SDK-internal compaction (PreCompact hook → start, next msg → end)
        compaction = CompactionTracker()
@@ -2290,8 +2584,19 @@ async def stream_chat_completion_sdk(
                sid,
            )

+        # Use SystemPromptPreset for cross-user prompt caching.
+        # WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when
+        # excludeDynamicSections=True is in the initialize request AND
+        # --resume is active. Disable the preset on resumed turns.
+        # Turn 1 still gets the preset (no --resume).
+        _cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume
+        system_prompt_value = _build_system_prompt_value(
+            system_prompt,
+            cross_user_cache=_cross_user,
+        )
+
        sdk_options_kwargs: dict[str, Any] = {
-            "system_prompt": system_prompt,
+            "system_prompt": system_prompt_value,
            "mcp_servers": {"copilot": mcp_server},
            "allowed_tools": allowed,
            "disallowed_tools": disallowed,
@@ -2330,13 +2635,19 @@ async def stream_chat_completion_sdk(
            # --session-id here. CLI >=2.1.97 rejects the combination of
            # --session-id + --resume unless --fork-session is also given.
            sdk_options_kwargs["resume"] = resume_file
-        elif not has_history:
-            # T1 only: write CLI native session to a predictable path so
-            # upload_cli_session() can find it after the turn completes.
-            # On T2+ without --resume the T1 session file already exists at
-            # that path; passing --session-id again would fail with
-            # "Session ID already in use". The upload guard also skips T2+
-            # no-resume turns, so --session-id provides no benefit there.
+        else:
+            # Set session_id whenever NOT resuming so the CLI writes the
+            # native session file to a predictable path for
+            # upload_cli_session() after the turn. This covers:
+            # • T1 fresh: no prior history, first SDK turn.
+            # • Mode-switch T1: has_history=True (prior baseline turns in
+            #   DB) but no CLI session file was ever uploaded — the CLI has
+            #   never been invoked with this session_id before.
+            # • T2+ without --resume (restore failed): no session file was
+            #   restored to local storage (restore_cli_session returned
+            #   False), so no conflict with an existing file.
+            # When --resume is active the session_id is already implied by
+            # the resume file; passing it again would be rejected by the CLI.
            sdk_options_kwargs["session_id"] = session_id
        # Optional explicit Claude Code CLI binary path (decouples the
        # bundled SDK version from the CLI version we run — needed because
@@ -2420,6 +2731,22 @@ async def stream_chat_completion_sdk(
        if attachments.hint:
            query_message = f"{query_message}\n\n{attachments.hint}"

+        # When running without --resume and no prior transcript in storage,
+        # seed the transcript builder from compressed DB messages so that
+        # upload_transcript saves a compact version for future turns.
+        if not use_resume and not transcript_content and not skip_transcript_upload:
+            (
+                transcript_content,
+                transcript_covers_prefix,
+                transcript_msg_count,
+            ) = await _seed_transcript(
+                session,
+                transcript_builder,
+                transcript_covers_prefix,
+                transcript_msg_count,
+                log_prefix,
+            )
+
        tried_compaction = False

        # Build the per-request context carrier (shared across attempts).
@@ -2502,12 +2829,14 @@ async def stream_chat_completion_sdk(
                        session_id,
                        sdk_cwd,
                        log_prefix,
+                        attempt=attempt,
                    )
                    state.transcript_builder = ctx.builder
                    state.use_resume = ctx.use_resume
                    state.resume_file = ctx.resume_file
                    tried_compaction = ctx.tried_compaction
                    state.transcript_msg_count = 0
+                    state.target_tokens = ctx.target_tokens
                    if ctx.transcript_lost:
                        skip_transcript_upload = True

@@ -2516,18 +2845,30 @@ async def stream_chat_completion_sdk(
                    if ctx.use_resume and ctx.resume_file:
                        sdk_options_kwargs_retry["resume"] = ctx.resume_file
                        sdk_options_kwargs_retry.pop("session_id", None)
-                    elif not has_history:
-                        # T1 retry: keep session_id so the CLI writes to the
-                        # predictable path for upload_cli_session().
+                    elif "session_id" in sdk_options_kwargs:
+                        # Initial invocation used session_id (T1 or mode-switch
+                        # T1): keep it so the CLI writes the session file to the
+                        # predictable path for upload_cli_session(). Storage is
+                        # ephemeral per invocation, so no "Session ID already in
+                        # use" conflict occurs — no prior file was restored.
                        sdk_options_kwargs_retry.pop("resume", None)
                        sdk_options_kwargs_retry["session_id"] = session_id
                    else:
                        # T2+ retry without --resume: do not pass --session-id.
                        # The T1 session file already exists at that path; re-using
                        # the same ID would fail with "Session ID already in use".
                        # The upload guard skips T2+ no-resume turns anyway.
                        sdk_options_kwargs_retry.pop("resume", None)
                        sdk_options_kwargs_retry.pop("session_id", None)
+                    # Recompute system_prompt for retry — ctx.use_resume may have
+                    # changed (context reduction enabled --resume). CLI 2.1.97
+                    # crashes when excludeDynamicSections=True is combined with
+                    # --resume, so disable the cross-user preset on resumed turns.
+                    _cross_user_retry = (
+                        config.claude_agent_cross_user_prompt_cache and not ctx.use_resume
+                    )
+                    sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value(
+                        system_prompt, cross_user_cache=_cross_user_retry
+                    )
                    state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry)  # type: ignore[arg-type] # dynamic kwargs
                    state.query_message, state.was_compacted = await _build_query_message(
                        current_message,
@@ -2535,6 +2876,7 @@ async def stream_chat_completion_sdk(
                        state.use_resume,
                        state.transcript_msg_count,
                        session_id,
+                        target_tokens=state.target_tokens,
                    )
                    if attachments.hint:
                        state.query_message = f"{state.query_message}\n\n{attachments.hint}"
@@ -2901,8 +3243,9 @@ async def stream_chat_completion_sdk(
                cache_creation_tokens=turn_cache_creation_tokens,
                log_prefix=log_prefix,
                cost_usd=turn_cost_usd,
-                model=config.model,
+                model=sdk_model or config.model,
                provider="anthropic",
+                model_cost_multiplier=model_cost_multiplier,
            )

        # --- Persist session messages ---
@@ -2939,7 +3282,7 @@ async def stream_chat_completion_sdk(

        # --- Graphiti: ingest conversation turn for temporal memory ---
        if graphiti_enabled and user_id and message and is_user_message:
-            from backend.copilot.graphiti.ingest import enqueue_conversation_turn
+            from ..graphiti.ingest import enqueue_conversation_turn

            _ingest_task = asyncio.create_task(
                enqueue_conversation_turn(user_id, session_id, message)
@@ -3020,6 +3363,21 @@ async def stream_chat_completion_sdk(
        # the shielded inner coroutine continues running to completion so the
        # upload is not lost. This is intentional and matches the pattern
        # used for upload_transcript immediately above.
+        #
+        # NOTE: upload is attempted regardless of state.use_resume — even when
+        # this turn ran without --resume (restore failed or first T2+ on a new
+        # pod), the T1 session file at the expected path may still be present
+        # and should be re-uploaded so the next turn can resume from it.
+        # upload_cli_session silently skips when the file is absent, so this is
+        # always safe.
+        #
+        # Intentionally NOT gated on skip_transcript_upload: that flag is set
+        # when our custom JSONL transcript is dropped (transcript_lost=True on
+        # reduced-context retries) but the CLI's native session file is written
+        # independently. Blocking CLI upload on transcript_lost would prevent
+        # T1 prompt-too-long retries from uploading their valid session file,
+        # breaking --resume on the next pod. The ended_with_stream_error gate
+        # above already covers actual turn failures.
        if (
            config.claude_agent_use_resume
            and user_id
@@ -3027,9 +3385,15 @@ async def stream_chat_completion_sdk(
            and session is not None
            and state is not None
            and not ended_with_stream_error
-            and not skip_transcript_upload
-            and (not has_history or state.use_resume)
        ):
+            logger.info(
+                "%s Attempting CLI session upload"
+                " (use_resume=%s, has_history=%s, skip_transcript=%s)",
+                log_prefix,
+                state.use_resume,
+                has_history,
+                skip_transcript_upload,
+            )
            try:
                await asyncio.shield(
                    upload_cli_session(
@@ -15,11 +15,14 @@ from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock

from .conftest import build_test_transcript as _build_transcript
from .service import (
+    _RETRY_TARGET_TOKENS,
    ReducedContext,
    _is_prompt_too_long,
    _is_tool_only_message,
    _iter_sdk_messages,
+    _normalize_model_name,
    _reduce_context,
+    _TokenUsage,
)

# ---------------------------------------------------------------------------
@@ -207,6 +210,24 @@ class TestReduceContext:

        assert ctx.transcript_lost is True

    @pytest.mark.asyncio
    async def test_drop_returns_target_tokens_attempt_1(self) -> None:
        ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=1)
        assert ctx.transcript_lost is True
        assert ctx.target_tokens == _RETRY_TARGET_TOKENS[0]

    @pytest.mark.asyncio
    async def test_drop_returns_target_tokens_attempt_2(self) -> None:
        ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=2)
        assert ctx.transcript_lost is True
        assert ctx.target_tokens == _RETRY_TARGET_TOKENS[1]

    @pytest.mark.asyncio
    async def test_drop_clamps_attempt_beyond_limits(self) -> None:
        ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=99)
        assert ctx.transcript_lost is True
        assert ctx.target_tokens == _RETRY_TARGET_TOKENS[-1]


# ---------------------------------------------------------------------------
# _iter_sdk_messages
@@ -331,3 +352,264 @@ class TestIsParallelContinuation:
        msg = MagicMock(spec=AssistantMessage)
        msg.content = [self._make_tool_block()]
        assert _is_tool_only_message(msg) is True


# ---------------------------------------------------------------------------
# _normalize_model_name — used by per-request model override
# ---------------------------------------------------------------------------


class TestNormalizeModelName:
    """Unit tests for the model-name normalisation helper.

    The per-request model toggle calls _normalize_model_name with either
    ``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
    'standard'). These tests verify the OpenRouter/provider-prefix stripping
    that keeps the value compatible with the Claude CLI.
    """

    def test_strips_anthropic_prefix(self):
        assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"

    def test_strips_openai_prefix(self):
        assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"

    def test_strips_google_prefix(self):
        assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"

    def test_already_normalized_unchanged(self):
        assert (
            _normalize_model_name("claude-sonnet-4-20250514")
            == "claude-sonnet-4-20250514"
        )

    def test_empty_string_unchanged(self):
        assert _normalize_model_name("") == ""

    def test_opus_model_roundtrip(self):
        """The exact string used for the 'opus' toggle strips correctly."""
        assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"

    def test_sonnet_openrouter_model(self):
        """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
        assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4"


# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
# ---------------------------------------------------------------------------


class TestTokenUsageNullSafety:
    """Verify that ResultMessage.usage dicts with null-valued cache fields
    (as emitted by OpenRouter for the initial streaming event before real
    token counts are available) do not crash the accumulator.

    Before the fix, dict.get("cache_read_input_tokens", 0) returned None
    when the key existed with a null value, causing 'int += None' TypeError.
    """

    def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
        """Null-safe accumulation: ``or 0`` treats missing/None as zero.

        Uses ``usage.get("key") or 0`` rather than ``usage.get("key", 0)``
        because the latter returns ``None`` when the key exists with a null
        value, which would raise ``TypeError`` on ``int += None``. This is
        the intentional pattern that fixes the OpenRouter initial-stream-event
        bug described in the class docstring.
        """
        acc.prompt_tokens += usage.get("input_tokens") or 0
        acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
        acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0
        acc.completion_tokens += usage.get("output_tokens") or 0

    def test_null_cache_tokens_do_not_crash(self):
        """OpenRouter initial event: cache keys present with null value."""
        usage = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_input_tokens": None,
            "cache_creation_input_tokens": None,
        }
        acc = _TokenUsage()
        self._apply_usage(usage, acc)  # must not raise TypeError
        assert acc.prompt_tokens == 0
        assert acc.cache_read_tokens == 0
        assert acc.cache_creation_tokens == 0
        assert acc.completion_tokens == 0

    def test_real_cache_tokens_are_accumulated(self):
        """OpenRouter final event: real cache token counts are captured."""
        usage = {
            "input_tokens": 10,
            "output_tokens": 349,
            "cache_read_input_tokens": 16600,
            "cache_creation_input_tokens": 512,
        }
        acc = _TokenUsage()
        self._apply_usage(usage, acc)
        assert acc.prompt_tokens == 10
        assert acc.cache_read_tokens == 16600
        assert acc.cache_creation_tokens == 512
        assert acc.completion_tokens == 349

    def test_absent_cache_keys_default_to_zero(self):
        """Minimal usage dict without cache keys defaults correctly."""
        usage = {"input_tokens": 5, "output_tokens": 20}
        acc = _TokenUsage()
        self._apply_usage(usage, acc)
        assert acc.prompt_tokens == 5
        assert acc.cache_read_tokens == 0
        assert acc.cache_creation_tokens == 0
        assert acc.completion_tokens == 20

    def test_multi_turn_accumulation(self):
        """Null event followed by real event: only real tokens counted."""
        null_event = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_input_tokens": None,
            "cache_creation_input_tokens": None,
        }
        real_event = {
            "input_tokens": 10,
            "output_tokens": 349,
            "cache_read_input_tokens": 16600,
            "cache_creation_input_tokens": 512,
        }
        acc = _TokenUsage()
        self._apply_usage(null_event, acc)
        self._apply_usage(real_event, acc)
        assert acc.prompt_tokens == 10
        assert acc.cache_read_tokens == 16600
        assert acc.cache_creation_tokens == 512
        assert acc.completion_tokens == 349


# ---------------------------------------------------------------------------
# session_id / resume selection logic
# ---------------------------------------------------------------------------


def _build_sdk_options(
    use_resume: bool,
    resume_file: str | None,
    session_id: str,
) -> dict:
    """Mirror the session_id/resume selection in stream_chat_completion_sdk.

    This helper encodes the exact branching so the unit tests stay in sync
    with the production code without needing to invoke the full generator.
    """
    kwargs: dict = {}
    if use_resume and resume_file:
        kwargs["resume"] = resume_file
    else:
        kwargs["session_id"] = session_id
    return kwargs


def _build_retry_sdk_options(
    initial_kwargs: dict,
    ctx_use_resume: bool,
    ctx_resume_file: str | None,
    session_id: str,
) -> dict:
    """Mirror the retry branch in stream_chat_completion_sdk."""
    retry: dict = dict(initial_kwargs)
    if ctx_use_resume and ctx_resume_file:
        retry["resume"] = ctx_resume_file
        retry.pop("session_id", None)
    elif "session_id" in initial_kwargs:
        retry.pop("resume", None)
        retry["session_id"] = session_id
    else:
        retry.pop("resume", None)
        retry.pop("session_id", None)
    return retry


class TestSdkSessionIdSelection:
    """Verify that session_id is set for all non-resume turns.

    Regression test for the mode-switch T1 bug: when a user switches from
    baseline mode (fast) to SDK mode (extended_thinking) mid-session, the
    first SDK turn has has_history=True but no CLI session file. The old
    code gated session_id on ``not has_history``, so mode-switch T1 never
    got a session_id — the CLI used a random ID that couldn't be found on
    the next turn, causing --resume to fail for the whole session.
    """

    SESSION_ID = "sess-abc123"

    def test_t1_fresh_sets_session_id(self):
        """T1 of a fresh session always gets session_id."""
        opts = _build_sdk_options(
            use_resume=False,
            resume_file=None,
            session_id=self.SESSION_ID,
        )
        assert opts.get("session_id") == self.SESSION_ID
        assert "resume" not in opts

    def test_mode_switch_t1_sets_session_id(self):
        """Mode-switch T1 (has_history=True, no CLI session) gets session_id.

        Before the fix, the ``elif not has_history`` guard prevented this
        case from setting session_id, causing all subsequent turns to run
        without --resume.
        """
        # Mode-switch T1: use_resume=False (no prior CLI session) and
        # has_history=True (prior baseline turns in DB). The old code
        # (``elif not has_history``) silently skipped this case.
        opts = _build_sdk_options(
            use_resume=False,
            resume_file=None,
            session_id=self.SESSION_ID,
        )
        assert opts.get("session_id") == self.SESSION_ID
        assert "resume" not in opts

    def test_t2_with_resume_uses_resume(self):
        """T2+ with a restored CLI session uses --resume, not session_id."""
        opts = _build_sdk_options(
            use_resume=True,
            resume_file=self.SESSION_ID,
            session_id=self.SESSION_ID,
        )
        assert opts.get("resume") == self.SESSION_ID
        assert "session_id" not in opts

    def test_t2_without_resume_sets_session_id(self):
        """T2+ when restore failed still gets session_id (no prior file on disk)."""
        opts = _build_sdk_options(
            use_resume=False,
            resume_file=None,
            session_id=self.SESSION_ID,
        )
        assert opts.get("session_id") == self.SESSION_ID
        assert "resume" not in opts

    def test_retry_keeps_session_id_for_t1(self):
        """Retry for T1 (or mode-switch T1) preserves session_id."""
        initial = _build_sdk_options(False, None, self.SESSION_ID)
        retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
        assert retry.get("session_id") == self.SESSION_ID
        assert "resume" not in retry

    def test_retry_removes_session_id_for_t2_plus(self):
        """Retry for T2+ (initial used --resume) removes session_id to avoid conflict."""
        initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
        # T2+ retry where context reduction dropped --resume
        retry = _build_retry_sdk_options(initial, False, None, self.SESSION_ID)
        assert "session_id" not in retry
        assert "resume" not in retry

    def test_retry_t2_with_resume_sets_resume(self):
        """Retry that still uses --resume keeps --resume and drops session_id."""
        initial = _build_sdk_options(True, self.SESSION_ID, self.SESSION_ID)
        retry = _build_retry_sdk_options(
            initial, True, self.SESSION_ID, self.SESSION_ID
        )
        assert retry.get("resume") == self.SESSION_ID
        assert "session_id" not in retry
@@ -8,7 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.copilot import config as cfg_mod

from .service import (
    _build_system_prompt_value,
    _is_sdk_disconnect_error,
    _normalize_model_name,
    _prepare_file_attachments,
@@ -397,6 +400,7 @@ _CONFIG_ENV_VARS = (
    "OPENAI_BASE_URL",
    "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
    "CHAT_USE_CLAUDE_AGENT_SDK",
    "CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
)


@@ -656,3 +660,62 @@ class TestSafeCloseSdkClient:
        client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
        with pytest.raises(ValueError, match="invalid argument"):
            await _safe_close_sdk_client(client, "[test]")


# ---------------------------------------------------------------------------
# SystemPromptPreset — cross-user prompt caching
# ---------------------------------------------------------------------------


class TestSystemPromptPreset:
    """Tests for _build_system_prompt_value — cross-user prompt caching."""

    def test_preset_dict_structure_when_enabled(self):
        """When cross_user_cache is True, returns a _SystemPromptPreset dict."""
        custom_prompt = "You are a helpful assistant."
        result = _build_system_prompt_value(custom_prompt, cross_user_cache=True)

        assert isinstance(result, dict)
        assert result["type"] == "preset"
        assert result["preset"] == "claude_code"
        assert result["append"] == custom_prompt
        assert result["exclude_dynamic_sections"] is True

    def test_raw_string_when_disabled(self):
        """When cross_user_cache is False, returns the raw string."""
        custom_prompt = "You are a helpful assistant."
        result = _build_system_prompt_value(custom_prompt, cross_user_cache=False)

        assert isinstance(result, str)
        assert result == custom_prompt

    def test_empty_string_with_cache_enabled(self):
        """Empty system_prompt with cross_user_cache=True produces append=''."""
        result = _build_system_prompt_value("", cross_user_cache=True)

        assert isinstance(result, dict)
        assert result["type"] == "preset"
        assert result["preset"] == "claude_code"
        assert result["append"] == ""
        assert result["exclude_dynamic_sections"] is True

    def test_default_config_is_enabled(self, _clean_config_env):
        """The default value for claude_agent_cross_user_prompt_cache is True."""
        cfg = cfg_mod.ChatConfig(
            use_openrouter=False,
            api_key=None,
            base_url=None,
            use_claude_code_subscription=False,
        )
        assert cfg.claude_agent_cross_user_prompt_cache is True

    def test_env_var_disables_cache(self, _clean_config_env, monkeypatch):
        """CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false disables caching."""
        monkeypatch.setenv("CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE", "false")
        cfg = cfg_mod.ChatConfig(
            use_openrouter=False,
            api_key=None,
            base_url=None,
            use_claude_code_subscription=False,
        )
        assert cfg.claude_agent_cross_user_prompt_cache is False

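# Illustrative sketch only, consistent with the tests above; the real
# implementation lives in the copilot service module. With cross-user
# caching enabled, the custom prompt is appended to the shared
# "claude_code" preset so the large preset prefix can be reused across
# users' prompt caches.
def _build_system_prompt_value_sketch(system_prompt, *, cross_user_cache):
    if not cross_user_cache:
        return system_prompt  # raw string, per-user cache only
    return {
        "type": "preset",
        "preset": "claude_code",
        "append": system_prompt,
        "exclude_dynamic_sections": True,
    }
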
@@ -960,7 +960,7 @@ class TestRunCompression:
        )
        call_count = [0]

        async def _compress_side_effect(*, messages, model, client):
        async def _compress_side_effect(*, messages, model, client, target_tokens=None):
            call_count[0] += 1
            if client is not None:
                # Simulate a hang that exceeds the timeout

@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
    )

    logger.debug(f"Successfully unsubscribed from session {session_id}")


async def disconnect_all_listeners(session_id: str) -> int:
    """Cancel every active listener task for *session_id*.

    Called when the frontend switches away from a session and wants the
    backend to release resources immediately rather than waiting for the
    XREAD timeout.

    Scope / limitations (best-effort optimisation, not a correctness primitive):
    - Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
      lands on a different worker than the one serving the SSE, no listener
      is cancelled here — the SSE worker still releases on its XREAD timeout.
    - Session-scoped (not subscriber-scoped): cancels every active listener
      for the session on this pod. In the rare case a single user opens two
      SSE connections to the same session on the same pod (e.g. two tabs),
      both would be torn down. Cross-pod, subscriber-scoped cancellation
      would require a Redis pub/sub fan-out with per-listener tokens; that
      is not implemented here because the XREAD timeout already bounds the
      worst case.

    Returns the number of listener tasks that were cancelled.
    """
    to_cancel: list[tuple[int, asyncio.Task]] = [
        (qid, task)
        for qid, (sid, task) in list(_listener_sessions.items())
        if sid == session_id and not task.done()
    ]

    for qid, task in to_cancel:
        _listener_sessions.pop(qid, None)
        task.cancel()

    cancelled = 0
    for _qid, task in to_cancel:
        try:
            await asyncio.wait_for(task, timeout=5.0)
        except asyncio.CancelledError:
            cancelled += 1
        except asyncio.TimeoutError:
            pass
        except Exception as e:
            logger.error(f"Error cancelling listener for session {session_id}: {e}")

    if cancelled:
        logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}")
    return cancelled

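# Illustrative call site only; the route path and handler name are
# assumptions, not taken from the actual router. A DELETE endpoint can
# release a session's listeners immediately instead of waiting for the
# XREAD timeout.
@router.delete("/sessions/{session_id}/stream")
async def stop_session_stream(session_id: str) -> dict:
    cancelled = await disconnect_all_listeners(session_id)
    # Zero is normal: the SSE may be served by another pod (pod-local scope).
    return {"cancelled_listeners": cancelled}
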
autogpt_platform/backend/backend/copilot/stream_registry_test.py (new file, +110)
@@ -0,0 +1,110 @@
"""Tests for disconnect_all_listeners in stream_registry."""

import asyncio
from unittest.mock import AsyncMock, patch

import pytest

from backend.copilot import stream_registry


@pytest.fixture(autouse=True)
def _clear_listener_sessions():
    stream_registry._listener_sessions.clear()
    yield
    stream_registry._listener_sessions.clear()


async def _sleep_forever():
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        raise


@pytest.mark.asyncio
async def test_disconnect_all_listeners_cancels_matching_session():
    task_a = asyncio.create_task(_sleep_forever())
    task_b = asyncio.create_task(_sleep_forever())
    task_other = asyncio.create_task(_sleep_forever())

    stream_registry._listener_sessions[1] = ("sess-1", task_a)
    stream_registry._listener_sessions[2] = ("sess-1", task_b)
    stream_registry._listener_sessions[3] = ("sess-other", task_other)

    try:
        cancelled = await stream_registry.disconnect_all_listeners("sess-1")

        assert cancelled == 2
        assert task_a.cancelled()
        assert task_b.cancelled()
        assert not task_other.done()
        # Matching entries are removed, non-matching entries remain.
        assert 1 not in stream_registry._listener_sessions
        assert 2 not in stream_registry._listener_sessions
        assert 3 in stream_registry._listener_sessions
    finally:
        task_other.cancel()
        try:
            await task_other
        except asyncio.CancelledError:
            pass


@pytest.mark.asyncio
async def test_disconnect_all_listeners_no_match_returns_zero():
    task = asyncio.create_task(_sleep_forever())
    stream_registry._listener_sessions[1] = ("sess-other", task)

    try:
        cancelled = await stream_registry.disconnect_all_listeners("sess-missing")

        assert cancelled == 0
        assert not task.done()
        assert 1 in stream_registry._listener_sessions
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


@pytest.mark.asyncio
async def test_disconnect_all_listeners_skips_already_done_tasks():
    async def _noop():
        return None

    done_task = asyncio.create_task(_noop())
    await done_task
    stream_registry._listener_sessions[1] = ("sess-1", done_task)

    cancelled = await stream_registry.disconnect_all_listeners("sess-1")

    # Done tasks are filtered out before cancellation.
    assert cancelled == 0


@pytest.mark.asyncio
async def test_disconnect_all_listeners_empty_registry():
    cancelled = await stream_registry.disconnect_all_listeners("sess-1")
    assert cancelled == 0


@pytest.mark.asyncio
async def test_disconnect_all_listeners_timeout_not_counted():
    """Tasks that don't respond to cancellation (timeout) are not counted."""
    task = asyncio.create_task(_sleep_forever())
    stream_registry._listener_sessions[1] = ("sess-1", task)

    with patch.object(
        asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError)
    ):
        cancelled = await stream_registry.disconnect_all_listeners("sess-1")

    assert cancelled == 0
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
@@ -96,6 +96,7 @@ async def persist_and_record_usage(
    cost_usd: float | str | None = None,
    model: str | None = None,
    provider: str = "open_router",
    model_cost_multiplier: float = 1.0,
) -> int:
    """Persist token usage to session and record for rate limiting.

@@ -109,6 +110,9 @@ async def persist_and_record_usage(
        log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
        cost_usd: Optional cost for logging (float from SDK, str otherwise).
        provider: Cost provider name (e.g. "anthropic", "open_router").
        model_cost_multiplier: Relative model cost factor for rate limiting
            (1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
            more expensive models deplete the rate limit proportionally faster.

    Returns:
        The computed total_tokens (prompt + completion; cache excluded).
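# Worked example (arithmetic only): with model_cost_multiplier=5.0, a call
# that used 800 prompt + 200 completion tokens depletes the rate limit as
# (800 + 200) * 5.0 = 5000 tokens, i.e. five times faster than the same
# call on a multiplier-1.0 model.
effective_tokens = int((800 + 200) * 5.0)  # == 5000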
@@ -163,6 +167,7 @@ async def persist_and_record_usage(
            completion_tokens=completion_tokens,
            cache_read_tokens=cache_read_tokens,
            cache_creation_tokens=cache_creation_tokens,
            model_cost_multiplier=model_cost_multiplier,
        )
    except Exception as usage_err:
        logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)

@@ -230,6 +230,7 @@ class TestRateLimitRecording:
            completion_tokens=50,
            cache_read_tokens=1000,
            cache_creation_tokens=200,
            model_cost_multiplier=1.0,
        )

    @pytest.mark.asyncio

@@ -74,6 +74,15 @@ class FindBlockTool(BaseTool):
                    "description": "Include full input/output schemas (for agent JSON generation).",
                    "default": False,
                },
                "for_agent_generation": {
                    "type": "boolean",
                    "description": (
                        "Set to true when searching for blocks to use inside an agent graph "
                        "(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). "
                        "Bypasses the CoPilot-only filter so graph-only blocks are visible."
                    ),
                    "default": False,
                },
            },
            "required": ["query"],
        }
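# Illustrative tool-call arguments only (values made up): how a model
# would invoke find_block while assembling an agent graph, opting in to
# graph-only blocks via the new flag.
find_block_args = {
    "query": "agent output",
    "include_schemas": True,  # full schemas for agent JSON generation
    "for_agent_generation": True,  # surface INPUT/OUTPUT/ORCHESTRATOR blocks
}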
@@ -88,6 +97,7 @@ class FindBlockTool(BaseTool):
        session: ChatSession,
        query: str = "",
        include_schemas: bool = False,
        for_agent_generation: bool = False,
        **kwargs,
    ) -> ToolResponseBase:
        """Search for blocks matching the query.

@@ -97,6 +107,8 @@ class FindBlockTool(BaseTool):
            session: Chat session
            query: Search query
            include_schemas: Whether to include block schemas in results
            for_agent_generation: When True, bypasses the CoPilot exclusion filter
                so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible.

        Returns:
            BlockListResponse: List of matching blocks

@@ -123,34 +135,36 @@ class FindBlockTool(BaseTool):
                suggestions=["Search for an alternative block by name"],
                session_id=session_id,
            )
        if (
        is_excluded = (
            block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
            or block.id in COPILOT_EXCLUDED_BLOCK_IDS
        ):
            if block.block_type == BlockType.MCP_TOOL:
        )
        if is_excluded:
            # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
            # exposed when building an agent graph so the LLM can inspect
            # their schemas and wire them as nodes. In CoPilot direct use
            # they are not executable — guide the LLM to the right tool.
            if not for_agent_generation:
                if block.block_type == BlockType.MCP_TOOL:
                    message = (
                        f"Block '{block.name}' (ID: {block.id}) cannot be "
                        "run directly in CoPilot. Use run_mcp_tool for "
                        "interactive MCP execution, or call find_block with "
                        "for_agent_generation=true to embed it in an agent graph."
                    )
                else:
                    message = (
                        f"Block '{block.name}' (ID: {block.id}) is not available "
                        "in CoPilot. It can only be used within agent graphs."
                    )
                return NoResultsResponse(
                    message=(
                        f"Block '{block.name}' (ID: {block.id}) is not "
                        "runnable through find_block/run_block. Use "
                        "run_mcp_tool instead."
                    ),
                    message=message,
                    suggestions=[
                        "Use run_mcp_tool to discover and run this MCP tool",
                        "Search for an alternative block by name",
                        "Use this block in an agent graph instead",
                    ],
                    session_id=session_id,
                )
            return NoResultsResponse(
                message=(
                    f"Block '{block.name}' (ID: {block.id}) is not available "
                    "in CoPilot. It can only be used within agent graphs."
                ),
                suggestions=[
                    "Search for an alternative block by name",
                    "Use this block in an agent graph instead",
                ],
                session_id=session_id,
            )

        # Check block-level permissions — hide denied blocks entirely
        perms = get_current_permissions()
@@ -221,8 +235,9 @@ class FindBlockTool(BaseTool):
            if not block or block.disabled:
                continue

            # Skip blocks excluded from CoPilot (graph-only blocks)
            if (
            # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
            # skipped in CoPilot direct use but surfaced for agent graph building.
            if not for_agent_generation and (
                block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
                or block.id in COPILOT_EXCLUDED_BLOCK_IDS
            ):

@@ -12,7 +12,7 @@ from .find_block import (
    COPILOT_EXCLUDED_BLOCK_TYPES,
    FindBlockTool,
)
from .models import BlockListResponse
from .models import BlockListResponse, NoResultsResponse

_TEST_USER_ID = "test-user-find-block"

@@ -166,6 +166,194 @@ class TestFindBlockFiltering:
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "normal-block-id"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_for_agent_generation_exposes_excluded_blocks_in_search(self):
        """With for_agent_generation=True, excluded block types appear in search results."""
        session = make_session(user_id=_TEST_USER_ID)

        search_results = [
            {"content_id": "input-block-id", "score": 0.9},
            {"content_id": "output-block-id", "score": 0.8},
        ]
        input_block = make_mock_block("input-block-id", "Agent Input", BlockType.INPUT)
        output_block = make_mock_block(
            "output-block-id", "Agent Output", BlockType.OUTPUT
        )

        def mock_get_block(block_id):
            return {
                "input-block-id": input_block,
                "output-block-id": output_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    query="agent input",
                    for_agent_generation=True,
                )

        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 2
        block_ids = {b.id for b in response.blocks}
        assert "input-block-id" in block_ids
        assert "output-block-id" in block_ids

    @pytest.mark.asyncio(loop_scope="session")
    async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self):
        """MCP_TOOL blocks appear in search results when for_agent_generation=True."""
        session = make_session(user_id=_TEST_USER_ID)

        search_results = [
            {"content_id": "mcp-block-id", "score": 0.9},
            {"content_id": "standard-block-id", "score": 0.8},
        ]
        mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
        standard_block = make_mock_block(
            "standard-block-id", "Normal Block", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                "mcp-block-id": mcp_block,
                "standard-block-id": standard_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    query="mcp tool",
                    for_agent_generation=True,
                )

        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 2
        assert any(b.id == "mcp-block-id" for b in response.blocks)
        assert any(b.id == "standard-block-id" for b in response.blocks)

    @pytest.mark.asyncio(loop_scope="session")
    async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self):
        """MCP_TOOL blocks are excluded from search in normal CoPilot mode."""
        session = make_session(user_id=_TEST_USER_ID)

        search_results = [
            {"content_id": "mcp-block-id", "score": 0.9},
            {"content_id": "standard-block-id", "score": 0.8},
        ]
        mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
        standard_block = make_mock_block(
            "standard-block-id", "Normal Block", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                "mcp-block-id": mcp_block,
                "standard-block-id": standard_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    query="mcp tool",
                    for_agent_generation=False,
                )

        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "standard-block-id"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_for_agent_generation_exposes_excluded_ids_in_search(self):
        """With for_agent_generation=True, excluded block IDs appear in search results."""
        session = make_session(user_id=_TEST_USER_ID)
        orchestrator_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS))

        search_results = [
            {"content_id": orchestrator_id, "score": 0.9},
            {"content_id": "normal-block-id", "score": 0.8},
        ]
        orchestrator_block = make_mock_block(
            orchestrator_id, "Orchestrator", BlockType.STANDARD
        )
        normal_block = make_mock_block(
            "normal-block-id", "Normal Block", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                orchestrator_id: orchestrator_block,
                "normal-block-id": normal_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID,
                    session=session,
                    query="orchestrator",
                    for_agent_generation=True,
                )

        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 2
        block_ids = {b.id for b in response.blocks}
        assert orchestrator_id in block_ids
        assert "normal-block-id" in block_ids

    @pytest.mark.asyncio(loop_scope="session")
    async def test_response_size_average_chars_per_block(self):
        """Measure average chars per block in the serialized response."""
@@ -549,8 +737,6 @@ class TestFindBlockDirectLookup:
            user_id=_TEST_USER_ID, session=session, query=block_id
        )

        from .models import NoResultsResponse

        assert isinstance(response, NoResultsResponse)

    @pytest.mark.asyncio(loop_scope="session")
@@ -571,8 +757,6 @@ class TestFindBlockDirectLookup:
            user_id=_TEST_USER_ID, session=session, query=block_id
        )

        from .models import NoResultsResponse

        assert isinstance(response, NoResultsResponse)
        assert "disabled" in response.message.lower()

@@ -592,8 +776,6 @@ class TestFindBlockDirectLookup:
            user_id=_TEST_USER_ID, session=session, query=block_id
        )

        from .models import NoResultsResponse

        assert isinstance(response, NoResultsResponse)
        assert "not available" in response.message.lower()

@@ -613,7 +795,74 @@ class TestFindBlockDirectLookup:
            user_id=_TEST_USER_ID, session=session, query=orchestrator_id
        )

        from .models import NoResultsResponse

        assert isinstance(response, NoResultsResponse)
        assert "not available" in response.message.lower()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation(
        self,
    ):
        """With for_agent_generation=True, excluded block types (INPUT) are visible."""
        session = make_session(user_id=_TEST_USER_ID)
        block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        block = make_mock_block(block_id, "Agent Input Block", BlockType.INPUT)

        with patch(
            "backend.copilot.tools.find_block.get_block",
            return_value=block,
        ):
            tool = FindBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                query=block_id,
                for_agent_generation=True,
            )

        assert isinstance(response, BlockListResponse)
        assert response.count == 1
        assert response.blocks[0].id == block_id

    @pytest.mark.asyncio(loop_scope="session")
    async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self):
        """MCP_TOOL blocks are returned by UUID lookup when for_agent_generation=True."""
        session = make_session(user_id=_TEST_USER_ID)
        block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)

        with patch(
            "backend.copilot.tools.find_block.get_block",
            return_value=block,
        ):
            tool = FindBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                query=block_id,
                for_agent_generation=True,
            )

        assert isinstance(response, BlockListResponse)
        assert response.blocks[0].id == block_id

    @pytest.mark.asyncio(loop_scope="session")
    async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self):
        """MCP_TOOL blocks are excluded by UUID lookup in normal CoPilot mode."""
        session = make_session(user_id=_TEST_USER_ID)
        block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)

        with patch(
            "backend.copilot.tools.find_block.get_block",
            return_value=block,
        ):
            tool = FindBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                query=block_id,
                for_agent_generation=False,
            )

        assert isinstance(response, NoResultsResponse)
        assert "run_mcp_tool" in response.message

@@ -1179,6 +1179,7 @@ async def _run_compression(
    messages: list[dict],
    model: str,
    log_prefix: str,
    target_tokens: int | None = None,
) -> CompressResult:
    """Run LLM-based compression with truncation fallback.

@@ -1187,6 +1188,12 @@ async def _run_compression(
    truncation-based compression which drops older messages without
    summarization.

    ``target_tokens`` sets a hard token ceiling for the compressed output.
    When ``None``, ``compress_context`` derives the limit from the model's
    context window. Pass a smaller value on retries to force more aggressive
    compression — the compressor will LLM-summarize, content-truncate,
    middle-out delete, and first/last trim until the result fits.

    A 60-second timeout prevents a hung LLM call from blocking the
    retry path indefinitely. The truncation fallback also has a
    30-second timeout to guard against slow tokenization on very large
@@ -1196,18 +1203,27 @@ async def _run_compression(
    if client is None:
        logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
        return await asyncio.wait_for(
            compress_context(messages=messages, model=model, client=None),
            compress_context(
                messages=messages, model=model, client=None, target_tokens=target_tokens
            ),
            timeout=_TRUNCATION_TIMEOUT_SECONDS,
        )
    try:
        return await asyncio.wait_for(
            compress_context(messages=messages, model=model, client=client),
            compress_context(
                messages=messages,
                model=model,
                client=client,
                target_tokens=target_tokens,
            ),
            timeout=_COMPACTION_TIMEOUT_SECONDS,
        )
    except Exception as e:
        logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
        return await asyncio.wait_for(
            compress_context(messages=messages, model=model, client=None),
            compress_context(
                messages=messages, model=model, client=None, target_tokens=target_tokens
            ),
            timeout=_TRUNCATION_TIMEOUT_SECONDS,
        )

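# Illustrative call pattern only (the 8_000 ceiling is an assumption, not
# from the source): the first attempt lets compress_context derive the
# limit from the model's context window; a retry passes an explicit,
# smaller target_tokens to force more aggressive compression.
first = await _run_compression(
    client=client, messages=messages, model=model, log_prefix="[T1]"
)
retry = await _run_compression(
    client=client,
    messages=messages,
    model=model,
    log_prefix="[T2]",
    target_tokens=8_000,
)
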
@@ -349,7 +349,7 @@ class UserCreditBase(ABC):
            CreditTransactionType.GRANT,
            CreditTransactionType.TOP_UP,
        ]:
            from backend.executor.manager import (
            from backend.executor.billing import (
                clear_insufficient_funds_notifications,
            )

@@ -554,7 +554,7 @@ class UserCreditBase(ABC):
            in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
        ):
            # Lazy import to avoid circular dependency with executor.manager
            from backend.executor.manager import (
            from backend.executor.billing import (
                clear_insufficient_funds_notifications,
            )

@@ -852,6 +852,7 @@ class NodeExecutionStats(BaseModel):
    output_token_count: int = 0
    cache_read_token_count: int = 0
    cache_creation_token_count: int = 0
    cost: int = 0
    extra_cost: int = 0
    extra_steps: int = 0
    provider_cost: float | None = None

@@ -8,6 +8,7 @@ from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel

from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson

@@ -139,7 +140,10 @@ class UserCostSummary(BaseModel):
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    total_cache_read_tokens: int = 0
    total_cache_creation_tokens: int = 0
    request_count: int
    cost_bearing_request_count: int = 0


class CostLogRow(BaseModel):
@@ -161,12 +165,27 @@ class CostLogRow(BaseModel):
    cache_creation_tokens: int | None = None


class CostBucket(BaseModel):
    bucket: str
    count: int


class PlatformCostDashboard(BaseModel):
    by_provider: list[ProviderCostSummary]
    by_user: list[UserCostSummary]
    total_cost_microdollars: int
    total_requests: int
    total_users: int
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    avg_input_tokens_per_request: float = 0.0
    avg_output_tokens_per_request: float = 0.0
    avg_cost_microdollars_per_request: float = 0.0
    cost_p50_microdollars: float = 0.0
    cost_p75_microdollars: float = 0.0
    cost_p95_microdollars: float = 0.0
    cost_p99_microdollars: float = 0.0
    cost_buckets: list[CostBucket] = []


def _si(row: dict, field: str) -> int:
@@ -196,6 +215,7 @@ def _build_prisma_where(
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
    graph_exec_id: str | None = None,
) -> PlatformCostLogWhereInput:
    """Build a Prisma WhereInput for PlatformCostLog filters."""
    where: PlatformCostLogWhereInput = {}
@@ -223,9 +243,78 @@ def _build_prisma_where(
    if tracking_type:
        where["trackingType"] = tracking_type

    if graph_exec_id:
        where["graphExecId"] = graph_exec_id

    return where

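# Quick illustration (values match the tests further below): provider is
# lowercased and graph_exec_id maps to the graphExecId column.
where_example = _build_prisma_where(None, None, "OpenAI", None, graph_exec_id="exec-123")
assert where_example == {"provider": "openai", "graphExecId": "exec-123"}
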
def _build_raw_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
    graph_exec_id: str | None = None,
) -> tuple[str, list]:
    """Build a parameterised WHERE clause for raw SQL queries.

    Mirrors the filter logic of ``_build_prisma_where`` so there is a single
    source of truth for which columns are filtered and how. The first clause
    always restricts to ``cost_usd`` tracking type unless *tracking_type* is
    explicitly provided by the caller.
    """
    params: list = []
    clauses: list[str] = []
    idx = 1

    # Always filter by tracking type — defaults to cost_usd for percentile /
    # bucket queries that only make sense on cost-denominated rows.
    tt = tracking_type if tracking_type is not None else "cost_usd"
    clauses.append(f'"trackingType" = ${idx}')
    params.append(tt)
    idx += 1

    if start is not None:
        clauses.append(f'"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1

    if end is not None:
        clauses.append(f'"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1

    if provider is not None:
        clauses.append(f'"provider" = ${idx}')
        params.append(provider.lower())
        idx += 1

    if user_id is not None:
        clauses.append(f'"userId" = ${idx}')
        params.append(user_id)
        idx += 1

    if model is not None:
        clauses.append(f'"model" = ${idx}')
        params.append(model)
        idx += 1

    if block_name is not None:
        clauses.append(f'LOWER("blockName") = LOWER(${idx})')
        params.append(block_name)
        idx += 1

    if graph_exec_id is not None:
        clauses.append(f'"graphExecId" = ${idx}')
        params.append(graph_exec_id)
        idx += 1

    return (" AND ".join(clauses), params)

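# Quick illustration (derived from the function above): with only a
# provider filter, the clause and parameter list come out as shown.
sql, params = _build_raw_where(None, None, "OpenAI", None)
assert sql == '"trackingType" = $1 AND "provider" = $2'
assert params == ["cost_usd", "openai"]
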
@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
    start: datetime | None = None,
@@ -235,6 +324,7 @@ async def get_platform_cost_dashboard(
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
    graph_exec_id: str | None = None,
) -> PlatformCostDashboard:
    """Aggregate platform cost logs for the admin dashboard.

@@ -251,7 +341,22 @@ async def get_platform_cost_dashboard(
        start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)

    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
        start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
    )

    # For per-user tracking-type breakdown we intentionally omit the
    # tracking_type filter so cost_usd and tokens rows are always present.
    # This ensures cost_bearing_request_count is correct even when the caller
    # is filtering the main view by a different tracking_type.
    where_no_tracking_type = _build_prisma_where(
        start,
        end,
        provider,
        user_id,
        model,
        block_name,
        tracking_type=None,
        graph_exec_id=graph_exec_id,
    )

    sum_fields = {
@@ -264,39 +369,159 @@ async def get_platform_cost_dashboard(
        "trackingAmount": True,
    }

    # Run all four aggregation queries in parallel.
    by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
        await asyncio.gather(
            # (provider, trackingType, model) aggregation — no ORDER BY in ORM;
            # sort by total cost descending in Python after fetch.
    # Build parameterised WHERE clause for the raw SQL percentile/bucket
    # queries. Uses _build_raw_where so filter logic is shared with
    # _build_prisma_where and only maintained in one place.
    # Always force tracking_type=None here so _build_raw_where defaults to
    # "cost_usd" — percentile and histogram queries only make sense on
    # cost-denominated rows, regardless of what the caller is filtering.
    raw_where, raw_params = _build_raw_where(
        start,
        end,
        provider,
        user_id,
        model,
        block_name,
        tracking_type=None,
        graph_exec_id=graph_exec_id,
    )

    # Queries that always run regardless of tracking_type filter.
    common_queries = [
        # (provider, trackingType, model) aggregation — no ORDER BY in ORM;
        # sort by total cost descending in Python after fetch.
        PrismaLog.prisma().group_by(
            by=["provider", "trackingType", "model"],
            where=where,
            sum=sum_fields,
            count=True,
        ),
        # userId aggregation — emails fetched separately below.
        PrismaLog.prisma().group_by(
            by=["userId"],
            where=where,
            sum=sum_fields,
            count=True,
        ),
        # Per-user cost-bearing request count: group by (userId, trackingType)
        # so we can compute the correct denominator for per-user avg cost.
        # Uses where_no_tracking_type so cost_usd rows are always included
        # even when the caller filters the main view by a different tracking_type.
        PrismaLog.prisma().group_by(
            by=["userId", "trackingType"],
            where=where_no_tracking_type,
            count=True,
        ),
        # Distinct user count: group by userId, count groups.
        PrismaLog.prisma().group_by(
            by=["userId"],
            where=where,
            count=True,
        ),
        # Total aggregate (filtered): group by (provider, trackingType) so we can
        # compute cost-bearing and token-bearing denominators for avg stats.
        PrismaLog.prisma().group_by(
            by=["provider", "trackingType"],
            where=where,
            sum={
                "costMicrodollars": True,
                "inputTokens": True,
                "outputTokens": True,
            },
            count=True,
        ),
        # Percentile distribution of cost per request (respects all filters).
        query_raw_with_schema(
            "SELECT"
            " percentile_cont(0.5) WITHIN GROUP"
            ' (ORDER BY "costMicrodollars") as p50,'
            " percentile_cont(0.75) WITHIN GROUP"
            ' (ORDER BY "costMicrodollars") as p75,'
            " percentile_cont(0.95) WITHIN GROUP"
            ' (ORDER BY "costMicrodollars") as p95,'
            " percentile_cont(0.99) WITHIN GROUP"
            ' (ORDER BY "costMicrodollars") as p99'
            ' FROM {schema_prefix}"PlatformCostLog"'
            f" WHERE {raw_where}",
            *raw_params,
        ),
        # Histogram buckets for cost distribution (respects all filters).
        # NULL costMicrodollars is excluded explicitly to prevent such rows
        # from falling through all WHEN clauses into the ELSE '$10+' bucket.
        query_raw_with_schema(
            "SELECT"
            " CASE"
            ' WHEN "costMicrodollars" < 500000'
            " THEN '$0-0.50'"
            ' WHEN "costMicrodollars" < 1000000'
            " THEN '$0.50-1'"
            ' WHEN "costMicrodollars" < 2000000'
            " THEN '$1-2'"
            ' WHEN "costMicrodollars" < 5000000'
            " THEN '$2-5'"
            ' WHEN "costMicrodollars" < 10000000'
            " THEN '$5-10'"
            " ELSE '$10+'"
            " END as bucket,"
            " COUNT(*) as count"
            ' FROM {schema_prefix}"PlatformCostLog"'
            f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
            " GROUP BY bucket"
            ' ORDER BY MIN("costMicrodollars")',
            *raw_params,
        ),
    ]

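    # Unit sanity check (arithmetic only): costs are stored in microdollars,
    # so 1_000_000 equals $1.00 and the bucket edges above correspond to
    # 500000 -> $0.50, 1000000 -> $1, 2000000 -> $2, 5000000 -> $5,
    # 10000000 -> $10.
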
    # Only run the unfiltered aggregate query when tracking_type is set;
    # when tracking_type is None, the filtered query already contains all
    # tracking types and reusing it avoids a redundant full aggregation.
    if tracking_type is not None:
        common_queries.append(
            # Total aggregate (no tracking_type filter): used to compute
            # cost_bearing_requests and token_bearing_requests denominators so
            # global avg stats remain meaningful when the caller filters the
            # main view by a specific tracking_type (e.g. 'tokens').
            PrismaLog.prisma().group_by(
                by=["provider", "trackingType", "model"],
                where=where,
                sum=sum_fields,
                by=["provider", "trackingType"],
                where=where_no_tracking_type,
                sum={
                    "costMicrodollars": True,
                    "inputTokens": True,
                    "outputTokens": True,
                },
                count=True,
            ),
            # userId aggregation — emails fetched separately below.
            PrismaLog.prisma().group_by(
                by=["userId"],
                where=where,
                sum=sum_fields,
                count=True,
            ),
            # Distinct user count: group by userId, count groups.
            PrismaLog.prisma().group_by(
                by=["userId"],
                where=where,
                count=True,
            ),
            # Total aggregate: group by provider (no limit) to sum across all
            # matching rows. Summed in Python to get grand totals.
            PrismaLog.prisma().group_by(
                by=["provider"],
                where=where,
                sum={"costMicrodollars": True},
                count=True,
            ),
        )
    )

    results = await asyncio.gather(*common_queries)

    # Unpack results by name for clarity.
    by_provider_groups = results[0]
    by_user_groups = results[1]
    by_user_tracking_groups = results[2]
    total_user_groups = results[3]
    total_agg_groups = results[4]
    percentile_rows = results[5]
    bucket_rows = results[6]
    # When tracking_type is None, the filtered and unfiltered queries are
    # identical — reuse total_agg_groups to avoid the extra DB round-trip.
    total_agg_no_tracking_type_groups = (
        results[7] if tracking_type is not None else total_agg_groups
    )

    # Compute token grand-totals from the unfiltered aggregate so they remain
    # consistent with the avg-token stats (which also use unfiltered data).
    # Using by_provider_groups here would give 0 tokens when tracking_type='cost_usd'
    # because cost_usd rows carry no token data, contradicting non-zero averages.
    total_input_tokens = sum(
        _si(r, "inputTokens")
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "tokens"
    )
    total_output_tokens = sum(
        _si(r, "outputTokens")
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "tokens"
    )

    # Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
@@ -323,6 +548,61 @@ async def get_platform_cost_dashboard(
    total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
    total_requests = sum(_ca(r) for r in total_agg_groups)

    # Extract percentile values from the raw query result.
    pctl = percentile_rows[0] if percentile_rows else {}
    cost_p50 = float(pctl.get("p50") or 0)
    cost_p75 = float(pctl.get("p75") or 0)
    cost_p95 = float(pctl.get("p95") or 0)
    cost_p99 = float(pctl.get("p99") or 0)

    # Build cost bucket list.
    cost_buckets: list[CostBucket] = [
        CostBucket(bucket=r["bucket"], count=int(r["count"])) for r in bucket_rows
    ]

    # Avg-stat numerators and denominators are derived from the unfiltered
    # aggregate so they remain meaningful when the caller filters by a specific
    # tracking_type. Example: filtering by 'tokens' excludes cost_usd rows from
    # total_agg_groups, so avg_cost would always be 0 if we used that; using
    # total_agg_no_tracking_type_groups gives the correct cost_usd total/count.
    avg_cost_total = sum(
        _si(r, "costMicrodollars")
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "cost_usd"
    )
    cost_bearing_requests = sum(
        _ca(r)
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "cost_usd"
    )
    avg_input_total = sum(
        _si(r, "inputTokens")
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "tokens"
    )
    avg_output_total = sum(
        _si(r, "outputTokens")
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "tokens"
    )
    # Token-bearing request count: only rows where trackingType == "tokens".
    # Token averages must use this denominator; cost_usd rows do not carry tokens.
    token_bearing_requests = sum(
        _ca(r)
        for r in total_agg_no_tracking_type_groups
        if r.get("trackingType") == "tokens"
    )

    # Per-user cost-bearing request count: used for per-user avg cost so the
    # denominator matches the numerator (cost_usd rows only, per user).
    user_cost_bearing_counts: dict[str, int] = {}
    for r in by_user_tracking_groups:
        if r.get("trackingType") == "cost_usd" and r.get("userId"):
            uid = r["userId"]
            user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
                r
            )

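    # Worked example (numbers mirror the tests below, arithmetic only): given
    # one cost_usd row (10_000 microdollars over 4 requests) and one tokens
    # row (1000 in / 500 out over 5 requests) in the unfiltered aggregate:
    #   avg_cost_microdollars_per_request = 10_000 / 4 = 2500.0
    #   avg_input_tokens_per_request      = 1000 / 5   = 200.0
    #   avg_output_tokens_per_request     = 500 / 5    = 100.0
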
    return PlatformCostDashboard(
        by_provider=[
            ProviderCostSummary(
@@ -347,13 +627,38 @@ async def get_platform_cost_dashboard(
                total_cost_microdollars=_si(r, "costMicrodollars"),
                total_input_tokens=_si(r, "inputTokens"),
                total_output_tokens=_si(r, "outputTokens"),
                total_cache_read_tokens=_si(r, "cacheReadTokens"),
                total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
                request_count=_ca(r),
                cost_bearing_request_count=user_cost_bearing_counts.get(
                    r.get("userId") or "", 0
                ),
            )
            for r in by_user_groups
        ],
        total_cost_microdollars=total_cost,
        total_requests=total_requests,
        total_users=total_users,
        total_input_tokens=total_input_tokens,
        total_output_tokens=total_output_tokens,
        avg_input_tokens_per_request=(
            avg_input_total / token_bearing_requests
            if token_bearing_requests > 0
            else 0.0
        ),
        avg_output_tokens_per_request=(
            avg_output_total / token_bearing_requests
            if token_bearing_requests > 0
            else 0.0
        ),
        avg_cost_microdollars_per_request=(
            avg_cost_total / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
        ),
        cost_p50_microdollars=cost_p50,
        cost_p75_microdollars=cost_p75,
        cost_p95_microdollars=cost_p95,
        cost_p99_microdollars=cost_p99,
        cost_buckets=cost_buckets,
    )


@@ -367,12 +672,13 @@ async def get_platform_cost_logs(
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
    graph_exec_id: str | None = None,
) -> tuple[list[CostLogRow], int]:
    if start is None:
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)

    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
        start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
    )
    offset = (page - 1) * page_size

@@ -422,6 +728,7 @@ async def get_platform_cost_logs_for_export(
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
    graph_exec_id: str | None = None,
) -> tuple[list[CostLogRow], bool]:
    """Return all matching rows up to EXPORT_MAX_ROWS.

@@ -432,7 +739,7 @@ async def get_platform_cost_logs_for_export(
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)

    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
        start, end, provider, user_id, model, block_name, tracking_type, graph_exec_id
    )

    rows = await PrismaLog.prisma().find_many(

@@ -10,6 +10,8 @@ from backend.util.json import SafeJson

from .platform_cost import (
    PlatformCostEntry,
    _build_prisma_where,
    _build_raw_where,
    _build_where,
    _mask_email,
    get_platform_cost_dashboard,
@@ -156,6 +158,101 @@ class TestBuildWhere:
        assert 'p."trackingType" = $3' in sql


class TestBuildPrismaWhere:
    def test_both_start_and_end(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        where = _build_prisma_where(start, end, None, None)
        assert where["createdAt"] == {"gte": start, "lte": end}

    def test_end_only(self):
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        where = _build_prisma_where(None, end, None, None)
        assert where["createdAt"] == {"lte": end}

    def test_start_only(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        where = _build_prisma_where(start, None, None, None)
        assert where["createdAt"] == {"gte": start}

    def test_no_filters(self):
        where = _build_prisma_where(None, None, None, None)
        assert "createdAt" not in where

    def test_provider_lowercased(self):
        where = _build_prisma_where(None, None, "OpenAI", None)
        assert where["provider"] == "openai"

    def test_model_filter(self):
        where = _build_prisma_where(None, None, None, None, model="gpt-4")
        assert where["model"] == "gpt-4"

    def test_block_name_case_insensitive(self):
        where = _build_prisma_where(None, None, None, None, block_name="LLMBlock")
        assert where["blockName"] == {"equals": "LLMBlock", "mode": "insensitive"}

    def test_tracking_type(self):
        where = _build_prisma_where(None, None, None, None, tracking_type="tokens")
        assert where["trackingType"] == "tokens"

    def test_graph_exec_id_filter(self):
        where = _build_prisma_where(None, None, None, None, graph_exec_id="exec-123")
        assert where["graphExecId"] == "exec-123"

    def test_graph_exec_id_none_not_included(self):
        where = _build_prisma_where(None, None, None, None, graph_exec_id=None)
        assert "graphExecId" not in where


class TestBuildRawWhere:
    def test_end_filter(self):
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_raw_where(None, end, None, None)
        assert '"createdAt" <= $2::timestamptz' in sql
        assert end in params

    def test_model_filter(self):
        sql, params = _build_raw_where(None, None, None, None, model="gpt-4")
        assert '"model" = $' in sql
        assert "gpt-4" in params

    def test_block_name_filter(self):
        sql, params = _build_raw_where(None, None, None, None, block_name="LLMBlock")
        assert 'LOWER("blockName") = LOWER($' in sql
        assert "LLMBlock" in params

    def test_all_filters_combined(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_raw_where(
            start, end, "anthropic", "u1", model="claude-3", block_name="LLM"
        )
        # trackingType (default), start, end, provider, user_id, model, block_name
        assert len(params) == 7
        assert "anthropic" in params
        assert "u1" in params
        assert "claude-3" in params
        assert "LLM" in params

    def test_default_tracking_type_is_cost_usd(self):
        sql, params = _build_raw_where(None, None, None, None)
        assert '"trackingType" = $1' in sql
        assert params[0] == "cost_usd"

    def test_explicit_tracking_type_overrides_default(self):
        sql, params = _build_raw_where(None, None, None, None, tracking_type="tokens")
        assert params[0] == "tokens"

    def test_graph_exec_id_filter(self):
        sql, params = _build_raw_where(None, None, None, None, graph_exec_id="exec-abc")
        assert '"graphExecId" = $' in sql
        assert "exec-abc" in params

    def test_graph_exec_id_not_included_when_none(self):
        sql, params = _build_raw_where(None, None, None, None)
        assert "graphExecId" not in sql


def _make_entry(**overrides: object) -> PlatformCostEntry:
|
||||
return PlatformCostEntry.model_validate(
|
||||
{
|
||||
@@ -286,8 +383,9 @@ class TestGetPlatformCostDashboard:
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups (no cost_usd rows for this user)
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (tracking_type=None → same as unfiltered)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[mock_user])
|
||||
@@ -301,6 +399,14 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[
|
||||
[{"p50": 1000, "p75": 2000, "p95": 4000, "p99": 5000}],
|
||||
[{"bucket": "$0-0.50", "count": 3}],
|
||||
],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
@@ -313,6 +419,131 @@ class TestGetPlatformCostDashboard:
|
||||
assert dashboard.by_provider[0].total_duration_seconds == 10.5
|
||||
assert len(dashboard.by_user) == 1
|
||||
assert dashboard.by_user[0].email == "a***@b.com"
|
||||
assert dashboard.cost_p50_microdollars == 1000
|
||||
assert dashboard.cost_p75_microdollars == 2000
|
||||
assert dashboard.cost_p95_microdollars == 4000
|
||||
assert dashboard.cost_p99_microdollars == 5000
|
||||
assert len(dashboard.cost_buckets) == 1
|
||||
# total_input/output_tokens come from total_agg_no_tracking_type_groups
|
||||
# (provider_row has 1000/500)
|
||||
assert dashboard.total_input_tokens == 1000
|
||||
assert dashboard.total_output_tokens == 500
|
||||
# Token averages must use token_bearing_requests (3) not cost_bearing (0)
|
||||
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
|
||||
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 3)
|
||||
# No cost_usd rows in total_agg → avg_cost should be 0
|
||||
assert dashboard.avg_cost_microdollars_per_request == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_bearing_request_count_nonzero_when_filtering_by_tokens(self):
|
||||
"""When filtering by tracking_type='tokens', cost_bearing_request_count
|
||||
must still reflect cost_usd rows because by_user_tracking_groups is
|
||||
queried without the tracking_type constraint."""
|
||||
# total_agg only has a tokens row (because of the tracking_type filter)
|
||||
total_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="tokens", cost=0, count=5
|
||||
)
|
||||
# by_user_tracking_groups returns BOTH rows (no tracking_type filter)
|
||||
user_tracking_cost_usd_row = {
|
||||
"_count": {"_all": 7},
|
||||
"userId": "u1",
|
||||
"trackingType": "cost_usd",
|
||||
}
|
||||
user_tracking_tokens_row = {
|
||||
"_count": {"_all": 5},
|
||||
"userId": "u1",
|
||||
"trackingType": "tokens",
|
||||
}
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[total_row], # by_provider
|
||||
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
|
||||
[
|
||||
user_tracking_cost_usd_row,
|
||||
user_tracking_tokens_row,
|
||||
], # by_user_tracking
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[total_row], # total agg (filtered)
|
||||
[total_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# by_user has 1 user with 5 total requests (tokens rows only due to filter)
|
||||
# but per-user cost_bearing count should be 7 (from cost_usd rows in
|
||||
# by_user_tracking_groups which uses where_no_tracking_type)
|
||||
assert len(dashboard.by_user) == 1
|
||||
assert dashboard.by_user[0].cost_bearing_request_count == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_avg_cost_nonzero_when_filtering_by_tokens(self):
|
||||
"""When filtering by tracking_type='tokens', avg_cost_microdollars_per_request
|
||||
must still reflect cost_usd rows from total_agg_no_tracking_type_groups,
|
||||
not the filtered total_agg_groups which only has tokens rows."""
|
||||
# filtered total_agg only has tokens rows (zero cost)
|
||||
tokens_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="tokens", cost=0, count=5
|
||||
)
|
||||
# unfiltered total_agg has both rows (cost_usd carries the actual cost)
|
||||
cost_usd_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="cost_usd", cost=10_000, count=4
|
||||
)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[tokens_row], # by_provider
|
||||
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[tokens_row], # total agg (filtered — tokens only)
|
||||
[tokens_row, cost_usd_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# avg_cost_microdollars_per_request must be non-zero: cost_usd row
|
||||
# (10_000 microdollars, 4 requests) is present in the unfiltered agg.
|
||||
assert dashboard.avg_cost_microdollars_per_request == pytest.approx(10_000 / 4)
|
||||
# avg token stats use token_bearing_requests from unfiltered agg (5)
|
||||
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 5)
|
||||
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_aggregated_not_hardcoded(self):
|
||||
@@ -335,8 +566,9 @@ class TestGetPlatformCostDashboard:
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u2"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (tracking_type=None → same as unfiltered)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
@@ -350,6 +582,14 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[
|
||||
[{"p50": 0, "p75": 0, "p95": 0, "p99": 0}],
|
||||
[],
|
||||
],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
@@ -361,7 +601,7 @@ class TestGetPlatformCostDashboard:
    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])

        with (
@@ -373,6 +613,11 @@ class TestGetPlatformCostDashboard:
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.query_raw_with_schema",
                new_callable=AsyncMock,
                side_effect=[[], []],
            ),
        ):
            dashboard = await get_platform_cost_dashboard()

@@ -381,13 +626,56 @@ class TestGetPlatformCostDashboard:
        assert dashboard.total_users == 0
        assert dashboard.by_provider == []
        assert dashboard.by_user == []
        assert dashboard.cost_p50_microdollars == 0
        assert dashboard.cost_buckets == []

    @pytest.mark.asyncio
    async def test_passes_filters_to_queries(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)

        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])

        raw_mock = AsyncMock(side_effect=[[], []])
        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.query_raw_with_schema",
                raw_mock,
            ),
        ):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )

        # group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users,
        # total agg filtered); the 6th call (total agg no-tracking-type) only runs
        # when tracking_type is set.
        assert mock_actions.group_by.await_count == 5
        # The where dict passed to the first call should include createdAt
        first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
        assert "createdAt" in first_call_kwargs.get("where", {})
        # Raw SQL queries should receive provider and user_id as parameters
        assert raw_mock.await_count == 2
        raw_call_args = raw_mock.call_args_list[0][0]  # positional args of 1st call
        raw_params = raw_call_args[1:]  # first arg is the query template
        assert "openai" in raw_params
        assert "u1" in raw_params

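    # How the raw-SQL assertions above read the mock, as a sketch:
    #     query, *params = raw_mock.call_args_list[0][0]
    # call_args_list[0][0] is the positional-args tuple of the first await,
    # so dropping its first element (the query template) leaves only the
    # bound parameters that should carry the provider and user id.
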
    @pytest.mark.asyncio
    async def test_user_tracking_groups_excludes_tracking_type_filter(self):
        """by_user_tracking_groups must NOT apply the tracking_type filter so that
        cost_usd rows are always included even when the caller filters by 'tokens'."""
        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])

        with (
@@ -399,16 +687,54 @@ class TestGetPlatformCostDashboard:
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.query_raw_with_schema",
                new_callable=AsyncMock,
                side_effect=[[], []],
            ),
        ):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )
            await get_platform_cost_dashboard(tracking_type="tokens")

        # group_by called 4 times (by_provider, by_user, distinct users, totals)
        assert mock_actions.group_by.await_count == 4
        # The where dict passed to the first call should include createdAt
        first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
        assert "createdAt" in first_call_kwargs.get("where", {})
        # Call index 2 is by_user_tracking_groups (0=by_provider, 1=by_user,
        # 2=by_user_tracking, 3=distinct_users, 4=total_agg, 5=total_agg_no_tt).
        tracking_call_where = mock_actions.group_by.call_args_list[2][1]["where"]
        # The main filter applies trackingType; by_user_tracking must NOT.
        assert "trackingType" not in tracking_call_where
        # Other filters (e.g., date range, provider) are still passed through.
        # The first call (by_provider) should have trackingType in its where dict.
        provider_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
        assert "trackingType" in provider_call_where

    @pytest.mark.asyncio
    async def test_graph_exec_id_filter_passed_to_queries(self):
        """graph_exec_id must be forwarded to both prisma and raw SQL queries."""
        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])
        raw_mock = AsyncMock(side_effect=[[], []])

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.query_raw_with_schema",
                raw_mock,
            ),
        ):
            await get_platform_cost_dashboard(graph_exec_id="exec-xyz")

        # Prisma groupBy where must include graphExecId
        first_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
        assert first_call_where.get("graphExecId") == "exec-xyz"
        # Raw SQL params must include the exec id
        raw_params = raw_mock.call_args_list[0][0][1:]
        assert "exec-xyz" in raw_params


def _make_prisma_log_row(
@@ -509,6 +835,21 @@ class TestGetPlatformCostLogs:
        # start provided — should appear in the where filter
        assert "createdAt" in where

    @pytest.mark.asyncio
    async def test_graph_exec_id_filter(self):
        mock_actions = MagicMock()
        mock_actions.count = AsyncMock(return_value=0)
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(graph_exec_id="exec-abc")

        where = mock_actions.count.call_args[1]["where"]
        assert where.get("graphExecId") == "exec-abc"


class TestGetPlatformCostLogsForExport:
    @pytest.mark.asyncio
@@ -594,6 +935,24 @@ class TestGetPlatformCostLogsForExport:
        assert logs[0].cache_read_tokens == 50
        assert logs[0].cache_creation_tokens == 25

    @pytest.mark.asyncio
    async def test_graph_exec_id_filter(self):
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export(
                graph_exec_id="exec-xyz"
            )

        where = mock_actions.find_many.call_args[1]["where"]
        assert where.get("graphExecId") == "exec-xyz"
        assert logs == []
        assert truncated is False

    @pytest.mark.asyncio
    async def test_explicit_start_skips_default(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)

509
autogpt_platform/backend/backend/executor/billing.py
Normal file
@@ -0,0 +1,509 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any, cast

from backend.blocks import get_block
from backend.blocks._base import Block
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
    ExecutionStatus,
    GraphExecutionEntry,
    NodeExecutionEntry,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.notifications.notifications import queue_notification
from backend.util.clients import (
    get_database_manager_client,
    get_notification_manager_client,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.logging import TruncatedLogger
from backend.util.metrics import DiscordChannel
from backend.util.settings import Settings

from .utils import LogMetadata, block_usage_cost, execution_usage_cost

if TYPE_CHECKING:
    from backend.data.db_manager import DatabaseManagerClient

_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[Billing]")
settings = Settings()

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60

# Hard cap on the multiplier passed to charge_extra_runtime_cost to
# protect against a corrupted llm_call_count draining a user's balance.
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
# 200 leaves headroom while preventing runaway charges.
_MAX_EXTRA_RUNTIME_COST = 200

def get_db_client() -> "DatabaseManagerClient":
    return get_database_manager_client()

async def clear_insufficient_funds_notifications(user_id: str) -> int:
    """
    Clear all insufficient funds notification flags for a user.

    This should be called when a user tops up their credits, allowing
    Discord notifications to be sent again if they run out of funds.

    Args:
        user_id: The user ID to clear notifications for.

    Returns:
        The number of keys that were deleted.
    """
    try:
        redis_client = await redis.get_redis_async()
        pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
        keys = [key async for key in redis_client.scan_iter(match=pattern)]
        if keys:
            return await redis_client.delete(*keys)
        return 0
    except Exception as e:
        logger.warning(
            f"Failed to clear insufficient funds notification flags for user "
            f"{user_id}: {e}"
        )
        return 0

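# Minimal usage sketch (hypothetical call site; the real hook lives in the
# credit top-up path): clear the flags as soon as the balance goes up, so the
# next genuine shortfall can alert again.
#
#     async def on_credits_granted(user_id: str) -> None:
#         cleared = await clear_insufficient_funds_notifications(user_id)
#         logger.debug(f"Cleared {cleared} insufficient-funds flags for {user_id}")
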
def resolve_block_cost(
    node_exec: NodeExecutionEntry,
) -> tuple["Block | None", int, dict[str, Any]]:
    """Look up the block and compute its base usage cost for an exec.

    Shared by charge_usage and charge_extra_runtime_cost so the
    (get_block, block_usage_cost) lookup lives in exactly one place.
    Returns ``(block, cost, matching_filter)``. ``block`` is ``None`` if
    the block id can't be resolved — callers should treat that as
    "nothing to charge".
    """
    block = get_block(node_exec.block_id)
    if not block:
        logger.error(f"Block {node_exec.block_id} not found.")
        return None, 0, {}
    cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.inputs)
    return block, cost, matching_filter

def charge_usage(
    node_exec: NodeExecutionEntry,
    execution_count: int,
) -> tuple[int, int]:
    total_cost = 0
    remaining_balance = 0
    db_client = get_db_client()
    block, cost, matching_filter = resolve_block_cost(node_exec)
    if not block:
        return total_cost, 0

    if cost > 0:
        remaining_balance = db_client.spend_credits(
            user_id=node_exec.user_id,
            cost=cost,
            metadata=UsageTransactionMetadata(
                graph_exec_id=node_exec.graph_exec_id,
                graph_id=node_exec.graph_id,
                node_exec_id=node_exec.node_exec_id,
                node_id=node_exec.node_id,
                block_id=node_exec.block_id,
                block=block.name,
                input=matching_filter,
                reason=f"Ran block {node_exec.block_id} {block.name}",
            ),
        )
        total_cost += cost

    # execution_count=0 is used by charge_node_usage for nested tool calls
    # which must not be pushed into higher execution-count tiers.
    # execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
    # so skip it entirely when execution_count is 0.
    cost, usage_count = (
        execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
    )
    if cost > 0:
        remaining_balance = db_client.spend_credits(
            user_id=node_exec.user_id,
            cost=cost,
            metadata=UsageTransactionMetadata(
                graph_exec_id=node_exec.graph_exec_id,
                graph_id=node_exec.graph_id,
                input={
                    "execution_count": usage_count,
                    "charge": "Execution Cost",
                },
                reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
            ),
        )
        total_cost += cost

    return total_cost, remaining_balance

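# Worked edge case for the execution_count guard above, assuming a charge
# interval of 100 executions: 100 % 100 == 0 and 200 % 100 == 0 charge as
# intended, but 0 % 100 == 0 as well, so a nested call with execution_count=0
# would otherwise be billed as if it had just crossed a tier.
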
def _charge_extra_runtime_cost_sync(
    node_exec: NodeExecutionEntry,
    capped_count: int,
) -> tuple[int, int]:
    """Synchronous implementation — runs in a thread-pool worker.

    Called only from charge_extra_runtime_cost. Do not call directly from
    async code.

    Note: ``resolve_block_cost`` is called again here (rather than reusing
    the result from ``charge_usage`` at the start of execution) because the
    two calls happen in separate thread-pool workers and sharing mutable
    state across workers would require locks. The block config is immutable
    during a run, so the repeated lookup is safe and produces the same cost;
    the only overhead is an extra registry lookup.
    """
    db_client = get_db_client()
    block, cost, matching_filter = resolve_block_cost(node_exec)
    if not block or cost <= 0:
        return 0, 0
    total_extra_cost = cost * capped_count
    remaining_balance = db_client.spend_credits(
        user_id=node_exec.user_id,
        cost=total_extra_cost,
        metadata=UsageTransactionMetadata(
            graph_exec_id=node_exec.graph_exec_id,
            graph_id=node_exec.graph_id,
            node_exec_id=node_exec.node_exec_id,
            node_id=node_exec.node_id,
            block_id=node_exec.block_id,
            block=block.name,
            input={
                **matching_filter,
                "extra_runtime_cost_count": capped_count,
            },
            reason=(
                f"Extra agent-mode iterations for {block.name} "
                f"({capped_count} additional LLM calls)"
            ),
        ),
    )
    return total_extra_cost, remaining_balance

async def charge_extra_runtime_cost(
    node_exec: NodeExecutionEntry,
    extra_count: int,
) -> tuple[int, int]:
    """Charge a block extra runtime cost beyond the initial run.

    Used by agent-mode blocks (e.g. OrchestratorBlock) that make multiple
    LLM calls within a single node execution. The first iteration is already
    charged by charge_usage; this method charges *extra_count* additional
    copies of the block's base cost.

    Returns ``(total_extra_cost, remaining_balance)``. May raise
    ``InsufficientBalanceError`` if the user can't afford the charge.
    """
    if extra_count <= 0:
        return 0, 0
    # Cap to protect against a corrupted llm_call_count.
    capped = min(extra_count, _MAX_EXTRA_RUNTIME_COST)
    if extra_count > _MAX_EXTRA_RUNTIME_COST:
        logger.warning(
            f"extra_count {extra_count} exceeds cap {_MAX_EXTRA_RUNTIME_COST};"
            f" charging {_MAX_EXTRA_RUNTIME_COST} (llm_call_count may be corrupted)"
        )
    return await asyncio.to_thread(_charge_extra_runtime_cost_sync, node_exec, capped)

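# Sketch of the intended call pattern (hypothetical block code): after an
# agent-mode run that made ``llm_call_count`` LLM calls, bill everything past
# the first call, which charge_usage already covered.
#
#     extra = stats.llm_call_count - 1
#     if extra > 0:
#         await charge_extra_runtime_cost(node_exec, extra)
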
async def charge_node_usage(node_exec: NodeExecutionEntry) -> tuple[int, int]:
    """Charge a single node execution to the user.

    Public async wrapper around charge_usage for blocks (e.g. the
    OrchestratorBlock) that spawn nested node executions outside the main
    queue and therefore need to charge them explicitly.

    Also handles low-balance notification so callers don't need to touch
    private functions directly.

    Note: this **does not** increment the global execution counter
    (``increment_execution_count``). Nested tool executions are sub-steps
    of a single block run from the user's perspective and should not push
    them into higher per-execution cost tiers.
    """

    def _run():
        total_cost, remaining = charge_usage(node_exec, 0)
        if total_cost > 0:
            handle_low_balance(
                get_db_client(), node_exec.user_id, remaining, total_cost
            )
        return total_cost, remaining

    return await asyncio.to_thread(_run)

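# Hypothetical caller sketch: an orchestrator block charging a nested tool
# execution that it spawned outside the main execution queue.
#
#     total_cost, remaining = await charge_node_usage(tool_node_exec)
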
async def try_send_insufficient_funds_notif(
    user_id: str,
    graph_id: str,
    error: InsufficientBalanceError,
    log_metadata: LogMetadata,
) -> None:
    """Send an insufficient-funds notification, swallowing failures."""
    try:
        await asyncio.to_thread(
            handle_insufficient_funds_notif,
            get_db_client(),
            user_id,
            graph_id,
            error,
        )
    except Exception as notif_error:  # pragma: no cover
        log_metadata.warning(
            f"Failed to send insufficient funds notification: {notif_error}"
        )

async def handle_post_execution_billing(
    node: Node,
    node_exec: NodeExecutionEntry,
    execution_stats: NodeExecutionStats,
    status: ExecutionStatus,
    log_metadata: LogMetadata,
) -> None:
    """Charge extra runtime cost for blocks that opt into per-LLM-call billing.

    The first LLM call is already covered by charge_usage(); each additional
    call costs another base_cost. Skipped for dry runs and failed runs.

    InsufficientBalanceError here is a post-hoc billing leak: the work is
    already done but the user can no longer pay. The run stays COMPLETED and
    the error is logged with ``billing_leak: True`` for alerting.
    """
    extra_iterations = (
        cast(Block, node.block).extra_runtime_cost(execution_stats)
        if status == ExecutionStatus.COMPLETED
        and not node_exec.execution_context.dry_run
        else 0
    )
    if extra_iterations <= 0:
        return

    try:
        extra_cost, remaining_balance = await charge_extra_runtime_cost(
            node_exec,
            extra_iterations,
        )
        if extra_cost > 0:
            execution_stats.extra_cost += extra_cost
            await asyncio.to_thread(
                handle_low_balance,
                get_db_client(),
                node_exec.user_id,
                remaining_balance,
                extra_cost,
            )
    except InsufficientBalanceError as e:
        log_metadata.error(
            "billing_leak: insufficient balance after "
            f"{node.block.name} completed {extra_iterations} "
            f"extra iterations",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_runtime_cost_count": extra_iterations,
                "error": str(e),
            },
        )
        # Do NOT set execution_stats.error — the node ran to completion,
        # only the post-hoc charge failed. See class-level billing-leak
        # contract documentation.
        await try_send_insufficient_funds_notif(
            node_exec.user_id,
            node_exec.graph_id,
            e,
            log_metadata,
        )
    except Exception as e:
        log_metadata.error(
            f"billing_leak: failed to charge extra iterations for {node.block.name}",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_runtime_cost_count": extra_iterations,
                "error_type": type(e).__name__,
                "error": str(e),
            },
            exc_info=True,
        )

def handle_agent_run_notif(
    db_client: "DatabaseManagerClient",
    graph_exec: GraphExecutionEntry,
    exec_stats: GraphExecutionStats,
) -> None:
    metadata = db_client.get_graph_metadata(
        graph_exec.graph_id, graph_exec.graph_version
    )
    outputs = db_client.get_node_executions(
        graph_exec.graph_exec_id,
        block_ids=[AgentOutputBlock().id],
    )

    named_outputs = [
        {
            key: value[0] if key == "name" else value
            for key, value in output.output_data.items()
        }
        for output in outputs
    ]

    queue_notification(
        NotificationEventModel(
            user_id=graph_exec.user_id,
            type=NotificationType.AGENT_RUN,
            data=AgentRunData(
                outputs=named_outputs,
                agent_name=metadata.name if metadata else "Unknown Agent",
                credits_used=exec_stats.cost,
                execution_time=exec_stats.walltime,
                graph_id=graph_exec.graph_id,
                node_count=exec_stats.node_count,
            ),
        )
    )

def handle_insufficient_funds_notif(
    db_client: "DatabaseManagerClient",
    user_id: str,
    graph_id: str,
    e: InsufficientBalanceError,
) -> None:
    # Check if we've already sent a notification for this user+agent combo.
    # We only send one notification per user per agent until they top up credits.
    redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
    try:
        redis_client = redis.get_redis()
        # SET NX returns True only if the key was newly set (didn't exist)
        is_new_notification = redis_client.set(
            redis_key,
            "1",
            nx=True,
            ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
        )
        if not is_new_notification:
            # Already notified for this user+agent, skip all notifications
            logger.debug(
                f"Skipping duplicate insufficient funds notification for "
                f"user={user_id}, graph={graph_id}"
            )
            return
    except Exception as redis_error:
        # If Redis fails, log and continue to send the notification
        # (better to occasionally duplicate than to never notify)
        logger.warning(
            f"Failed to check/set insufficient funds notification flag in Redis: "
            f"{redis_error}"
        )

    shortfall = abs(e.amount) - e.balance
    metadata = db_client.get_graph_metadata(graph_id)
    base_url = settings.config.frontend_base_url or settings.config.platform_base_url

    # Queue user email notification
    queue_notification(
        NotificationEventModel(
            user_id=user_id,
            type=NotificationType.ZERO_BALANCE,
            data=ZeroBalanceData(
                current_balance=e.balance,
                billing_page_link=f"{base_url}/profile/credits",
                shortfall=shortfall,
                agent_name=metadata.name if metadata else "Unknown Agent",
            ),
        )
    )

    # Send Discord system alert
    try:
        user_email = db_client.get_user_email_by_id(user_id)

        alert_message = (
            f"❌ **Insufficient Funds Alert**\n"
            f"User: {user_email or user_id}\n"
            f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
            f"Current balance: ${e.balance / 100:.2f}\n"
            f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
            f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
            f"[View User Details]({base_url}/admin/spending?search={user_email})"
        )

        get_notification_manager_client().discord_system_alert(
            alert_message, DiscordChannel.PRODUCT
        )
    except Exception as alert_error:
        logger.error(f"Failed to send insufficient funds Discord alert: {alert_error}")

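# The dedup above is plain SET NX + TTL; the same pattern stand-alone
# (a sketch assuming a synchronous redis-py client):
#
#     import redis
#
#     r = redis.Redis()
#     first_time = r.set(
#         "insufficient_funds_discord_notified:u1:g1", "1",
#         nx=True, ex=30 * 24 * 60 * 60,
#     )
#     if first_time:
#         ...  # first hit within 30 days: send the alert
#     else:
#         ...  # key already present: suppress the duplicate
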
def handle_low_balance(
    db_client: "DatabaseManagerClient",
    user_id: str,
    current_balance: int,
    transaction_cost: int,
) -> None:
    """Check and handle low balance scenarios after a transaction"""
    LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold

    balance_before = current_balance + transaction_cost

    if (
        current_balance < LOW_BALANCE_THRESHOLD
        and balance_before >= LOW_BALANCE_THRESHOLD
    ):
        base_url = (
            settings.config.frontend_base_url or settings.config.platform_base_url
        )
        queue_notification(
            NotificationEventModel(
                user_id=user_id,
                type=NotificationType.LOW_BALANCE,
                data=LowBalanceData(
                    current_balance=current_balance,
                    billing_page_link=f"{base_url}/profile/credits",
                ),
            )
        )

        try:
            user_email = db_client.get_user_email_by_id(user_id)
            alert_message = (
                f"⚠️ **Low Balance Alert**\n"
                f"User: {user_email or user_id}\n"
                f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
                f"Current balance: ${current_balance / 100:.2f}\n"
                f"Transaction cost: ${transaction_cost / 100:.2f}\n"
                f"[View User Details]({base_url}/admin/spending?search={user_email})"
            )
            get_notification_manager_client().discord_system_alert(
                alert_message, DiscordChannel.PRODUCT
            )
        except Exception as e:
            logger.warning(f"Failed to send low balance Discord alert: {e}")
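
# Worked threshold crossing, assuming LOW_BALANCE_THRESHOLD = 500 ($5):
#   current_balance = 400, transaction_cost = 600
#   balance_before  = 400 + 600 = 1000  (was at/above the threshold)
#   current_balance = 400 < 500         (is now below it)
# Only this before/after pair fires the notification, so further spends while
# already below the threshold stay silent.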
@@ -21,11 +21,9 @@ from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
    ExecutionContext,
@@ -39,27 +37,18 @@ from backend.data.execution import (
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
    drain_pending_cost_logs,
    log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
from backend.util.clients import (
    get_async_execution_event_bus,
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
    get_notification_manager_client,
)
from backend.util.decorator import (
    async_error_logged,
@@ -75,7 +64,6 @@ from backend.util.exceptions import (
)
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
    continuous_retry,
@@ -84,6 +72,7 @@ from backend.util.retry import (
)
from backend.util.settings import Settings

from . import billing
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
@@ -98,9 +87,7 @@ from .utils import (
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
    execution_usage_cost,
    validate_exec,
)

@@ -126,40 +113,6 @@ utilization_gauge = Gauge(
    "Ratio of active graph runs to max graph workers",
)

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60


async def clear_insufficient_funds_notifications(user_id: str) -> int:
    """
    Clear all insufficient funds notification flags for a user.

    This should be called when a user tops up their credits, allowing
    Discord notifications to be sent again if they run out of funds.

    Args:
        user_id: The user ID to clear notifications for.

    Returns:
        The number of keys that were deleted.
    """
    try:
        redis_client = await redis.get_redis_async()
        pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
        keys = [key async for key in redis_client.scan_iter(match=pattern)]
        if keys:
            return await redis_client.delete(*keys)
        return 0
    except Exception as e:
        logger.warning(
            f"Failed to clear insufficient funds notification flags for user "
            f"{user_id}: {e}"
        )
        return 0


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -681,12 +634,16 @@ class ExecutionProcessor:
            execution_stats.walltime = timing_info.wall_time
            execution_stats.cputime = timing_info.cpu_time

            await billing.handle_post_execution_billing(
                node, node_exec, execution_stats, status, log_metadata
            )

            graph_stats, graph_stats_lock = graph_stats_pair
            with graph_stats_lock:
                graph_stats.node_count += 1 + execution_stats.extra_steps
                graph_stats.nodes_cputime += execution_stats.cputime
                graph_stats.nodes_walltime += execution_stats.walltime
                graph_stats.cost += execution_stats.extra_cost
                graph_stats.cost += execution_stats.cost + execution_stats.extra_cost
                if isinstance(execution_stats.error, Exception):
                    graph_stats.node_error_count += 1

@@ -716,6 +673,18 @@ class ExecutionProcessor:
                db_client=db_client,
            )

            # If the node failed because a nested tool charge raised IBE,
            # send the user notification so they understand why the run stopped.
            if status == ExecutionStatus.FAILED and isinstance(
                execution_stats.error, InsufficientBalanceError
            ):
                await billing.try_send_insufficient_funds_notif(
                    node_exec.user_id,
                    node_exec.graph_id,
                    execution_stats.error,
                    log_metadata,
                )

            return execution_stats

    @async_time_measured
@@ -935,7 +904,7 @@ class ExecutionProcessor:
            )
        finally:
            # Communication handling
            self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
            billing.handle_agent_run_notif(db_client, graph_exec, exec_stats)

            update_graph_execution_state(
                db_client=db_client,
@@ -944,57 +913,18 @@ class ExecutionProcessor:
                stats=exec_stats,
            )

    def _charge_usage(
    async def charge_node_usage(
        self,
        node_exec: NodeExecutionEntry,
        execution_count: int,
    ) -> tuple[int, int]:
        total_cost = 0
        remaining_balance = 0
        db_client = get_db_client()
        block = get_block(node_exec.block_id)
        if not block:
            logger.error(f"Block {node_exec.block_id} not found.")
            return total_cost, 0
        return await billing.charge_node_usage(node_exec)

        cost, matching_filter = block_usage_cost(
            block=block, input_data=node_exec.inputs
        )
        if cost > 0:
            remaining_balance = db_client.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    node_exec_id=node_exec.node_exec_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    block=block.name,
                    input=matching_filter,
                    reason=f"Ran block {node_exec.block_id} {block.name}",
                ),
            )
            total_cost += cost

        cost, usage_count = execution_usage_cost(execution_count)
        if cost > 0:
            remaining_balance = db_client.spend_credits(
                user_id=node_exec.user_id,
                cost=cost,
                metadata=UsageTransactionMetadata(
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,
                    input={
                        "execution_count": usage_count,
                        "charge": "Execution Cost",
                    },
                    reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
                ),
            )
            total_cost += cost

        return total_cost, remaining_balance
    async def charge_extra_runtime_cost(
        self,
        node_exec: NodeExecutionEntry,
        extra_count: int,
    ) -> tuple[int, int]:
        return await billing.charge_extra_runtime_cost(node_exec, extra_count)

    @time_measured
    def _on_graph_execution(
@@ -1106,7 +1036,7 @@ class ExecutionProcessor:
                # Charge usage (may raise) — skipped for dry runs
                try:
                    if not graph_exec.execution_context.dry_run:
                        cost, remaining_balance = self._charge_usage(
                        cost, remaining_balance = billing.charge_usage(
                            node_exec=queued_node_exec,
                            execution_count=increment_execution_count(
                                graph_exec.user_id
@@ -1115,7 +1045,7 @@ class ExecutionProcessor:
                    with execution_stats_lock:
                        execution_stats.cost += cost
                    # Check if we crossed the low balance threshold
                    self._handle_low_balance(
                    billing.handle_low_balance(
                        db_client=db_client,
                        user_id=graph_exec.user_id,
                        current_balance=remaining_balance,
@@ -1135,7 +1065,7 @@ class ExecutionProcessor:
                        status=ExecutionStatus.FAILED,
                    )

                    self._handle_insufficient_funds_notif(
                    billing.handle_insufficient_funds_notif(
                        db_client,
                        graph_exec.user_id,
                        graph_exec.graph_id,
@@ -1397,165 +1327,6 @@ class ExecutionProcessor:
        ):
            execution_queue.add(next_execution)

    def _handle_agent_run_notif(
        self,
        db_client: "DatabaseManagerClient",
        graph_exec: GraphExecutionEntry,
        exec_stats: GraphExecutionStats,
    ):
        metadata = db_client.get_graph_metadata(
            graph_exec.graph_id, graph_exec.graph_version
        )
        outputs = db_client.get_node_executions(
            graph_exec.graph_exec_id,
            block_ids=[AgentOutputBlock().id],
        )

        named_outputs = [
            {
                key: value[0] if key == "name" else value
                for key, value in output.output_data.items()
            }
            for output in outputs
        ]

        queue_notification(
            NotificationEventModel(
                user_id=graph_exec.user_id,
                type=NotificationType.AGENT_RUN,
                data=AgentRunData(
                    outputs=named_outputs,
                    agent_name=metadata.name if metadata else "Unknown Agent",
                    credits_used=exec_stats.cost,
                    execution_time=exec_stats.walltime,
                    graph_id=graph_exec.graph_id,
                    node_count=exec_stats.node_count,
                ),
            )
        )

    def _handle_insufficient_funds_notif(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        graph_id: str,
        e: InsufficientBalanceError,
    ):
        # Check if we've already sent a notification for this user+agent combo.
        # We only send one notification per user per agent until they top up credits.
        redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
        try:
            redis_client = redis.get_redis()
            # SET NX returns True only if the key was newly set (didn't exist)
            is_new_notification = redis_client.set(
                redis_key,
                "1",
                nx=True,
                ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
            )
            if not is_new_notification:
                # Already notified for this user+agent, skip all notifications
                logger.debug(
                    f"Skipping duplicate insufficient funds notification for "
                    f"user={user_id}, graph={graph_id}"
                )
                return
        except Exception as redis_error:
            # If Redis fails, log and continue to send the notification
            # (better to occasionally duplicate than to never notify)
            logger.warning(
                f"Failed to check/set insufficient funds notification flag in Redis: "
                f"{redis_error}"
            )

        shortfall = abs(e.amount) - e.balance
        metadata = db_client.get_graph_metadata(graph_id)
        base_url = (
            settings.config.frontend_base_url or settings.config.platform_base_url
        )

        # Queue user email notification
        queue_notification(
            NotificationEventModel(
                user_id=user_id,
                type=NotificationType.ZERO_BALANCE,
                data=ZeroBalanceData(
                    current_balance=e.balance,
                    billing_page_link=f"{base_url}/profile/credits",
                    shortfall=shortfall,
                    agent_name=metadata.name if metadata else "Unknown Agent",
                ),
            )
        )

        # Send Discord system alert
        try:
            user_email = db_client.get_user_email_by_id(user_id)

            alert_message = (
                f"❌ **Insufficient Funds Alert**\n"
                f"User: {user_email or user_id}\n"
                f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
                f"Current balance: ${e.balance / 100:.2f}\n"
                f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
                f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
                f"[View User Details]({base_url}/admin/spending?search={user_email})"
            )

            get_notification_manager_client().discord_system_alert(
                alert_message, DiscordChannel.PRODUCT
            )
        except Exception as alert_error:
            logger.error(
                f"Failed to send insufficient funds Discord alert: {alert_error}"
            )

    def _handle_low_balance(
        self,
        db_client: "DatabaseManagerClient",
        user_id: str,
        current_balance: int,
        transaction_cost: int,
    ):
        """Check and handle low balance scenarios after a transaction"""
        LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold

        balance_before = current_balance + transaction_cost

        if (
            current_balance < LOW_BALANCE_THRESHOLD
            and balance_before >= LOW_BALANCE_THRESHOLD
        ):
            base_url = (
                settings.config.frontend_base_url or settings.config.platform_base_url
            )
            queue_notification(
                NotificationEventModel(
                    user_id=user_id,
                    type=NotificationType.LOW_BALANCE,
                    data=LowBalanceData(
                        current_balance=current_balance,
                        billing_page_link=f"{base_url}/profile/credits",
                    ),
                )
            )

            try:
                user_email = db_client.get_user_email_by_id(user_id)
                alert_message = (
                    f"⚠️ **Low Balance Alert**\n"
                    f"User: {user_email or user_id}\n"
                    f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
                    f"Current balance: ${current_balance / 100:.2f}\n"
                    f"Transaction cost: ${transaction_cost / 100:.2f}\n"
                    f"[View User Details]({base_url}/admin/spending?search={user_email})"
                )
                get_notification_manager_client().discord_system_alert(
                    alert_message, DiscordChannel.PRODUCT
                )
            except Exception as e:
                logger.warning(f"Failed to send low balance Discord alert: {e}")


class ExecutionManager(AppProcess):
    def __init__(self):

@@ -4,9 +4,9 @@ import pytest
from prisma.enums import NotificationType

from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
from backend.executor import billing
from backend.executor.billing import (
    INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
    ExecutionProcessor,
    clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
@@ -25,7 +25,6 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
):
    """Test that the first insufficient funds notification sends a Discord alert."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    graph_id = "test-graph-456"
    error = InsufficientBalanceError(
@@ -36,13 +35,13 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
    )

    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Setup mocks
@@ -63,7 +62,7 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"

        # Test the insufficient funds handler
        execution_processor._handle_insufficient_funds_notif(
        billing.handle_insufficient_funds_notif(
            db_client=mock_db_client,
            user_id=user_id,
            graph_id=graph_id,
@@ -99,7 +98,6 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
):
    """Test that duplicate insufficient funds notifications skip both email and Discord."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    graph_id = "test-graph-456"
    error = InsufficientBalanceError(
@@ -110,13 +108,13 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
    )

    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Setup mocks
@@ -134,7 +132,7 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
        mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")

        # Test the insufficient funds handler
        execution_processor._handle_insufficient_funds_notif(
        billing.handle_insufficient_funds_notif(
            db_client=mock_db_client,
            user_id=user_id,
            graph_id=graph_id,
@@ -154,7 +152,6 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
):
    """Test that different agents for the same user get separate Discord alerts."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    graph_id_1 = "test-graph-111"
    graph_id_2 = "test-graph-222"
@@ -166,12 +163,12 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
        amount=-714,
    )

    with patch("backend.executor.manager.queue_notification"), patch(
        "backend.executor.manager.get_notification_manager_client"
    with patch("backend.executor.billing.queue_notification"), patch(
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        mock_client = MagicMock()
@@ -190,7 +187,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"

        # First agent notification
        execution_processor._handle_insufficient_funds_notif(
        billing.handle_insufficient_funds_notif(
            db_client=mock_db_client,
            user_id=user_id,
            graph_id=graph_id_1,
@@ -198,7 +195,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
        )

        # Second agent notification
        execution_processor._handle_insufficient_funds_notif(
        billing.handle_insufficient_funds_notif(
            db_client=mock_db_client,
            user_id=user_id,
            graph_id=graph_id_2,
@@ -227,7 +224,7 @@ async def test_clear_insufficient_funds_notifications(server: SpinTestServer):

    user_id = "test-user-123"

    with patch("backend.executor.manager.redis") as mock_redis_module:
    with patch("backend.executor.billing.redis") as mock_redis_module:

        mock_redis_client = MagicMock()
        # get_redis_async is an async function, so we need AsyncMock for it
@@ -263,7 +260,7 @@ async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestSe

    user_id = "test-user-no-notifications"

    with patch("backend.executor.manager.redis") as mock_redis_module:
    with patch("backend.executor.billing.redis") as mock_redis_module:

        mock_redis_client = MagicMock()
        # get_redis_async is an async function, so we need AsyncMock for it
@@ -290,7 +287,7 @@ async def test_clear_insufficient_funds_notifications_handles_redis_error(

    user_id = "test-user-redis-error"

    with patch("backend.executor.manager.redis") as mock_redis_module:
    with patch("backend.executor.billing.redis") as mock_redis_module:

        # Mock get_redis_async to raise an error
        mock_redis_module.get_redis_async = AsyncMock(
@@ -310,7 +307,6 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
):
    """Test that both email and Discord notifications are still sent when Redis fails."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    graph_id = "test-graph-456"
    error = InsufficientBalanceError(
@@ -321,13 +317,13 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
    )

    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        mock_client = MagicMock()
@@ -346,7 +342,7 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"

        # Test the insufficient funds handler
        execution_processor._handle_insufficient_funds_notif(
        billing.handle_insufficient_funds_notif(
            db_client=mock_db_client,
            user_id=user_id,
            graph_id=graph_id,
@@ -370,7 +366,7 @@ async def test_add_transaction_clears_notifications_on_grant(server: SpinTestSer
    user_id = "test-user-grant-clear"

    with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Mock the query to return a successful transaction
@@ -412,7 +408,7 @@ async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestSe
    user_id = "test-user-topup-clear"

    with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Mock the query to return a successful transaction
@@ -450,7 +446,7 @@ async def test_add_transaction_skips_clearing_for_inactive_transaction(
    user_id = "test-user-inactive"

    with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Mock the query to return a successful transaction
@@ -486,7 +482,7 @@ async def test_add_transaction_skips_clearing_for_usage_transaction(
    user_id = "test-user-usage"

    with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
        "backend.executor.manager.redis"
        "backend.executor.billing.redis"
    ) as mock_redis_module:

        # Mock the query to return a successful transaction
@@ -521,7 +517,7 @@ async def test_enable_transaction_clears_notifications(server: SpinTestServer):

    with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
        "backend.data.credit.query_raw_with_schema"
    ) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
    ) as mock_query, patch("backend.executor.billing.redis") as mock_redis_module:

        # Mock finding the pending transaction
        mock_transaction = MagicMock()

@@ -4,26 +4,25 @@ import pytest
from prisma.enums import NotificationType

from backend.data.notifications import LowBalanceData
from backend.executor.manager import ExecutionProcessor
from backend.executor import billing
from backend.util.test import SpinTestServer


@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
    """Test that _handle_low_balance triggers notification when crossing threshold."""
    """Test that handle_low_balance triggers notification when crossing threshold."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    current_balance = 400  # $4 - below $5 threshold
    transaction_cost = 600  # $6 transaction

    # Mock dependencies
    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings:

        # Setup mocks
@@ -37,7 +36,7 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
        mock_db_client.get_user_email_by_id.return_value = "test@example.com"

        # Test the low balance handler
        execution_processor._handle_low_balance(
        billing.handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
@@ -69,7 +68,6 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
):
    """Test that no notification is sent when not crossing the threshold."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    current_balance = 600  # $6 - above $5 threshold
    transaction_cost = (
@@ -78,11 +76,11 @@ async def test_handle_low_balance_no_notification_when_not_crossing(

    # Mock dependencies
    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings:

        # Setup mocks
@@ -94,7 +92,7 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
        mock_db_client = MagicMock()

        # Test the low balance handler
        execution_processor._handle_low_balance(
        billing.handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,
@@ -112,7 +110,6 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
):
    """Test that no notification is sent when already below threshold."""

    execution_processor = ExecutionProcessor()
    user_id = "test-user-123"
    current_balance = 300  # $3 - below $5 threshold
    transaction_cost = (
@@ -121,11 +118,11 @@ async def test_handle_low_balance_no_duplicate_when_already_below(

    # Mock dependencies
    with patch(
        "backend.executor.manager.queue_notification"
        "backend.executor.billing.queue_notification"
    ) as mock_queue_notif, patch(
        "backend.executor.manager.get_notification_manager_client"
        "backend.executor.billing.get_notification_manager_client"
    ) as mock_get_client, patch(
        "backend.executor.manager.settings"
        "backend.executor.billing.settings"
    ) as mock_settings:

        # Setup mocks
@@ -137,7 +134,7 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
        mock_db_client = MagicMock()

        # Test the low balance handler
        execution_processor._handle_low_balance(
        billing.handle_low_balance(
            db_client=mock_db_client,
            user_id=user_id,
            current_balance=current_balance,

@@ -18,9 +18,13 @@ images: {
"""

import asyncio
import json
import random
from pathlib import Path
from typing import Any, Dict, List

import prisma.enums as prisma_enums
import prisma.models as prisma_models
from faker import Faker

# Import API functions from the backend
@@ -30,10 +34,12 @@ from backend.api.features.store.db import (
    create_store_submission,
    review_store_submission,
)
from backend.api.features.store.model import StoreSubmission
from backend.blocks.io import AgentInputBlock
from backend.data.auth.api_key import create_api_key
from backend.data.credit import get_user_credit_model
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.graph import Graph, Link, Node, create_graph, make_graph_model
from backend.data.user import get_or_create_user
from backend.util.clients import get_supabase

@@ -60,6 +66,31 @@ MAX_REVIEWS_PER_VERSION = 5
GUARANTEED_FEATURED_AGENTS = 8
GUARANTEED_FEATURED_CREATORS = 5
GUARANTEED_TOP_AGENTS = 10
E2E_MARKETPLACE_CREATOR_EMAIL = "test123@example.com"
E2E_MARKETPLACE_CREATOR_USERNAME = "e2e-marketplace"
E2E_MARKETPLACE_AGENT_SLUG = "e2e-calculator-agent"
E2E_MARKETPLACE_AGENT_NAME = "E2E Calculator Agent"
E2E_MARKETPLACE_AGENT_INPUT_VALUE = 8
E2E_MARKETPLACE_AGENT_OUTPUT_VALUE = 42
_LOCAL_TEMPLATE_PATH = (
    Path(__file__).resolve().parents[1] / "agents" / "calculator-agent.json"
)
_DOCKER_TEMPLATE_PATH = Path(
    "/app/autogpt_platform/backend/agents/calculator-agent.json"
)
E2E_MARKETPLACE_AGENT_TEMPLATE_PATH = (
    _LOCAL_TEMPLATE_PATH if _LOCAL_TEMPLATE_PATH.exists() else _DOCKER_TEMPLATE_PATH
)
SEEDED_TEST_EMAILS = [
    "test123@example.com",
    "e2e.qa.auth@example.com",
    "e2e.qa.builder@example.com",
    "e2e.qa.library@example.com",
    "e2e.qa.marketplace@example.com",
    "e2e.qa.settings@example.com",
    "e2e.qa.parallel.a@example.com",
    "e2e.qa.parallel.b@example.com",
]


def get_image():
@@ -100,6 +131,25 @@ def get_category():
    return random.choice(categories)


def load_deterministic_marketplace_graph() -> Graph:
    graph = Graph.model_validate(
        json.loads(E2E_MARKETPLACE_AGENT_TEMPLATE_PATH.read_text())
    )
    graph.name = E2E_MARKETPLACE_AGENT_NAME
    graph.description = (
        "Deterministic marketplace calculator graph for Playwright PR E2E coverage."
    )

    for node in graph.nodes:
        if (
            node.block_id == AgentInputBlock().id
            and node.input_default.get("value") is None
        ):
            node.input_default["value"] = E2E_MARKETPLACE_AGENT_INPUT_VALUE

    return graph

class TestDataCreator:
|
||||
"""Creates test data using API functions for E2E tests."""
|
||||
|
||||
@@ -123,9 +173,9 @@ class TestDataCreator:
|
||||
for i in range(NUM_USERS):
|
||||
try:
|
||||
# Generate test user data
|
||||
if i == 0:
|
||||
# First user should have test123@gmail.com email for testing
|
||||
email = "test123@gmail.com"
|
||||
if i < len(SEEDED_TEST_EMAILS):
|
||||
# Keep a deterministic pool for Playwright global setup and PR smoke flows
|
||||
email = SEEDED_TEST_EMAILS[i]
|
||||
else:
|
||||
email = faker.unique.email()
|
||||
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
|
||||
@@ -547,6 +597,46 @@ class TestDataCreator:
|
||||
print(f"Error updating profile {profile.id}: {e}")
|
||||
continue
|
||||
|
||||
deterministic_creator = next(
|
||||
(
|
||||
user
|
||||
for user in self.users
|
||||
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
|
||||
),
|
||||
None,
|
||||
)
|
||||
if deterministic_creator:
|
||||
deterministic_profile = next(
|
||||
(
|
||||
profile
|
||||
for profile in existing_profiles
|
||||
if profile.userId == deterministic_creator["id"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if deterministic_profile:
|
||||
try:
|
||||
updated_profile = await prisma.profile.update(
|
||||
where={"id": deterministic_profile.id},
|
||||
data={
|
||||
"name": "E2E Marketplace Creator",
|
||||
"username": E2E_MARKETPLACE_CREATOR_USERNAME,
|
||||
"description": "Deterministic marketplace creator for Playwright PR E2E coverage.",
|
||||
"links": ["https://example.com/e2e-marketplace"],
|
||||
"avatarUrl": get_image(),
|
||||
"isFeatured": True,
|
||||
},
|
||||
)
|
||||
profiles = [
|
||||
profile
|
||||
for profile in profiles
|
||||
if profile.get("id") != deterministic_profile.id
|
||||
]
|
||||
if updated_profile is not None:
|
||||
profiles.append(updated_profile.model_dump())
|
||||
except Exception as e:
|
||||
print(f"Error updating deterministic E2E creator profile: {e}")
|
||||
|
||||
self.profiles = profiles
|
||||
return profiles
|
||||
|
||||
@@ -562,58 +652,184 @@ class TestDataCreator:
|
||||
featured_count = 0
|
||||
submission_counter = 0
|
||||
|
||||
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
|
||||
# Create a deterministic calculator marketplace agent for PR E2E coverage
|
||||
test_user = next(
|
||||
(user for user in self.users if user["email"] == "test123@gmail.com"), None
|
||||
(
|
||||
user
|
||||
for user in self.users
|
||||
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
|
||||
),
|
||||
None,
|
||||
)
|
||||
if test_user and self.agent_graphs:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"graph_id": self.agent_graphs[0]["id"],
|
||||
"graph_version": 1,
|
||||
"slug": "test-agent-submission",
|
||||
"name": "Test Agent Submission",
|
||||
"sub_heading": "A test agent for frontend testing",
|
||||
"video_url": "https://www.youtube.com/watch?v=test123",
|
||||
"image_urls": [
|
||||
"https://picsum.photos/200/300",
|
||||
"https://picsum.photos/200/301",
|
||||
"https://picsum.photos/200/302",
|
||||
],
|
||||
"description": "This is a test agent submission specifically created for frontend testing purposes.",
|
||||
"categories": ["test", "demo", "frontend"],
|
||||
"changes_summary": "Initial test submission",
|
||||
}
|
||||
if test_user:
|
||||
deterministic_graph = None
|
||||
|
||||
try:
|
||||
test_submission = await create_store_submission(**test_submission_data)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print("✅ Created special test store submission for test123@gmail.com")
|
||||
|
||||
# ALWAYS approve and feature the test submission
|
||||
if test_submission.listing_version_id:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
reviewer_id=test_user["id"],
|
||||
existing_graph = await prisma_models.AgentGraph.prisma().find_first(
|
||||
where={
|
||||
"userId": test_user["id"],
|
||||
"name": E2E_MARKETPLACE_AGENT_NAME,
|
||||
"isActive": True,
|
||||
},
|
||||
order={"version": "desc"},
|
||||
)
|
||||
if existing_graph:
|
||||
deterministic_graph = {
|
||||
"id": existing_graph.id,
|
||||
"version": existing_graph.version,
|
||||
"name": existing_graph.name,
|
||||
"userId": test_user["id"],
|
||||
}
|
||||
self.agent_graphs.append(deterministic_graph)
|
||||
print(
|
||||
"✅ Reused existing deterministic marketplace graph: "
|
||||
f"{existing_graph.id}"
|
||||
)
|
||||
approved_submissions.append(approved_submission.model_dump())
|
||||
print("✅ Approved test store submission")
|
||||
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
else:
|
||||
deterministic_graph_model = make_graph_model(
|
||||
load_deterministic_marketplace_graph(),
|
||||
test_user["id"],
|
||||
)
|
||||
featured_count += 1
|
||||
print("🌟 Marked test agent as FEATURED")
|
||||
|
||||
deterministic_graph_model.reassign_ids(
|
||||
user_id=test_user["id"],
|
||||
reassign_graph_id=True,
|
||||
)
|
||||
created_deterministic_graph = await create_graph(
|
||||
deterministic_graph_model,
|
||||
test_user["id"],
|
||||
)
|
||||
deterministic_graph = created_deterministic_graph.model_dump()
|
||||
deterministic_graph["userId"] = test_user["id"]
|
||||
self.agent_graphs.append(deterministic_graph)
|
||||
print("✅ Created deterministic marketplace graph")
|
||||
except Exception as e:
|
||||
print(f"Error creating test store submission: {e}")
|
||||
import traceback
|
||||
print(f"Error creating deterministic marketplace graph: {e}")
|
||||
|
||||
traceback.print_exc()
|
||||
if deterministic_graph is None and self.agent_graphs:
|
||||
test_user_graphs = [
|
||||
graph
|
||||
for graph in self.agent_graphs
|
||||
if graph.get("userId") == test_user["id"]
|
||||
]
|
||||
deterministic_graph = next(
|
||||
(
|
||||
graph
|
||||
for graph in test_user_graphs
|
||||
if not graph.get("name", "").startswith("DummyInput ")
|
||||
),
|
||||
test_user_graphs[0] if test_user_graphs else None,
|
||||
)
|
||||
|
||||
if deterministic_graph:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"graph_id": deterministic_graph["id"],
|
||||
"graph_version": deterministic_graph.get("version", 1),
|
||||
"slug": E2E_MARKETPLACE_AGENT_SLUG,
|
||||
"name": E2E_MARKETPLACE_AGENT_NAME,
|
||||
"sub_heading": "A deterministic calculator agent for PR E2E coverage",
|
||||
"video_url": "https://www.youtube.com/watch?v=test123",
|
||||
"image_urls": [
|
||||
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
|
||||
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
|
||||
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
|
||||
],
|
||||
"description": (
|
||||
"A deterministic marketplace calculator agent that adds "
|
||||
f"{E2E_MARKETPLACE_AGENT_INPUT_VALUE} and 34 to produce "
|
||||
f"{E2E_MARKETPLACE_AGENT_OUTPUT_VALUE} for frontend E2E coverage."
|
||||
),
|
||||
"categories": ["test", "demo", "frontend"],
|
||||
"changes_summary": (
|
||||
"Initial deterministic calculator submission seeded from "
|
||||
"backend/agents/calculator-agent.json"
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
existing_deterministic_submission = (
|
||||
await prisma_models.StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"StoreListing": {
|
||||
"is": {
|
||||
"owningUserId": test_user["id"],
|
||||
"slug": E2E_MARKETPLACE_AGENT_SLUG,
|
||||
"isDeleted": False,
|
||||
}
|
||||
},
|
||||
},
|
||||
include={"StoreListing": True},
|
||||
order={"version": "desc"},
|
||||
)
|
||||
)
|
||||
|
||||
if existing_deterministic_submission:
|
||||
test_submission = StoreSubmission.from_listing_version(
|
||||
existing_deterministic_submission
|
||||
)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Reused deterministic marketplace submission: "
|
||||
f"{E2E_MARKETPLACE_AGENT_NAME}"
|
||||
)
|
||||
else:
|
||||
test_submission = await create_store_submission(
|
||||
**test_submission_data
|
||||
)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Created deterministic marketplace submission: "
|
||||
f"{E2E_MARKETPLACE_AGENT_NAME}"
|
||||
)
|
||||
|
||||
current_status = (
|
||||
existing_deterministic_submission.submissionStatus
|
||||
if existing_deterministic_submission
|
||||
else test_submission.status
|
||||
)
|
||||
is_featured = bool(
|
||||
existing_deterministic_submission
|
||||
and existing_deterministic_submission.isFeatured
|
||||
)
|
||||
|
||||
if test_submission.listing_version_id:
|
||||
if current_status != prisma_enums.SubmissionStatus.APPROVED:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Deterministic calculator submission approved",
|
||||
internal_comments="Auto-approved PR E2E marketplace submission",
|
||||
reviewer_id=test_user["id"],
|
||||
)
|
||||
approved_submissions.append(
|
||||
approved_submission.model_dump()
|
||||
)
|
||||
print("✅ Approved deterministic marketplace submission")
|
||||
else:
|
||||
approved_submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Deterministic marketplace submission already approved"
|
||||
)
|
||||
|
||||
if is_featured:
|
||||
featured_count += 1
|
||||
print("🌟 Deterministic marketplace agent already FEATURED")
|
||||
else:
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
print(
|
||||
"🌟 Marked deterministic marketplace agent as FEATURED"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating deterministic marketplace submission: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Create regular submissions for all users
|
||||
for user in self.users:
|
||||
|
||||
@@ -6,7 +6,8 @@
# 5. CLI arguments - docker compose run -e VAR=value

# Common backend environment - Docker service names
x-backend-env: &backend-env # Docker internal service hostnames (override localhost defaults)
x-backend-env:
&backend-env # Docker internal service hostnames (override localhost defaults)
PYRO_HOST: "0.0.0.0"
AGENTSERVER_HOST: rest_server
SCHEDULER_HOST: scheduler_server
@@ -39,7 +40,12 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: migrate
command: ["sh", "-c", "prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy"]
command:
[
"sh",
"-c",
"prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy",
]
develop:
watch:
- path: ./
@@ -79,8 +85,8 @@ services:
falkordb:
image: falkordb/falkordb:latest
ports:
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
environment:
- REDIS_ARGS=--requirepass ${GRAPHITI_FALKORDB_PASSWORD:-}
volumes:
@@ -88,7 +94,11 @@ services:
networks:
- app-network
healthcheck:
test: ["CMD-SHELL", "redis-cli -p 6379 -a \"${GRAPHITI_FALKORDB_PASSWORD:-}\" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1"]
test:
[
"CMD-SHELL",
'redis-cli -p 6379 -a "${GRAPHITI_FALKORDB_PASSWORD:-}" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1',
]
interval: 10s
timeout: 5s
retries: 5
@@ -300,19 +310,6 @@ services:
condition: service_completed_successfully
database_manager:
condition: service_started
# healthcheck:
#   test:
#     [
#       "CMD",
#       "curl",
#       "-f",
#       "-X",
#       "POST",
#       "http://localhost:8003/health_check",
#     ]
#   interval: 10s
#   timeout: 10s
#   retries: 5
<<: *backend-env-files
environment:
<<: *backend-env

@@ -193,3 +193,4 @@ services:
- copilot_executor
- websocket_server
- database_manager
- scheduler_server

@@ -8,6 +8,7 @@ const config: StorybookConfig = {
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/renderers/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/app/[(]platform[)]/copilot/**/*.stories.@(js|jsx|mjs|ts|tsx)",
],
addons: [
"@storybook/addon-a11y",

@@ -81,8 +81,10 @@ Every time a new Front-end dependency is added by you or others, you will need t
- `pnpm lint` - Run ESLint and Prettier checks
- `pnpm format` - Format code with Prettier
- `pnpm types` - Run TypeScript type checking
- `pnpm test` - Run Playwright tests
- `pnpm test-ui` - Run Playwright tests with UI
- `pnpm test:unit` - Run the Vitest integration and unit suite with coverage
- `pnpm test` - Run the Playwright E2E suite used in CI
- `pnpm test-ui` - Run the same Playwright E2E suite with UI
- `pnpm test:e2e:no-build` - Run the same Playwright E2E suite against a running app
- `pnpm fetch:openapi` - Fetch OpenAPI spec from backend
- `pnpm generate:api-client` - Generate API client from OpenAPI spec
- `pnpm generate:api` - Fetch OpenAPI spec and generate API client

@@ -121,35 +121,49 @@ Only when the component has complex internal logic that is hard to exercise thro
### Running

```bash
pnpm test # build + run all Playwright tests
pnpm test-ui # run with Playwright UI
pnpm test:no-build # run against a running dev server
pnpm test # build + run the Playwright E2E suite used in CI
pnpm test-ui # run the same E2E suite with Playwright UI
pnpm test:e2e:no-build # run the same E2E suite against a running dev server
pnpm exec playwright test # run the same eight-spec Playwright suite directly
```

### Setup

1. Start the backend + Supabase stack:
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
2. Seed rich E2E data (creates `test123@example.com` with library agents):
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`

### How Playwright setup works

- Playwright runs from `frontend/playwright.config.ts` with a global setup step
- Global setup creates a user pool via the real signup UI, stored in `frontend/.auth/user-pool.json`
- `getTestUser()` (from `src/tests/utils/auth.ts`) pulls a random user from the pool
- Playwright runs from `frontend/playwright.config.ts` and keeps browser-only code in `frontend/src/playwright/`
- Global setup creates reusable auth states for deterministic seeded accounts in `frontend/.auth/states/`
- `getTestUser()` (from `src/playwright/utils/auth.ts`) picks one seeded account for general auth coverage
- `getTestUserWithLibraryAgents()` uses the rich user created by the data script

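To make the setup concrete, here is a minimal sketch of how a spec can consume one of these pre-created auth states. The `getTestUser()` helper and the `.auth/states/` directory come from the setup described above; the per-email JSON file naming and the helper's exact return shape are assumptions for illustration only:

```ts
import { test, expect } from "@playwright/test";
import { getTestUser } from "./utils/auth";

test("library loads for a seeded user", async ({ browser }) => {
  const user = getTestUser(); // one deterministic seeded account
  // Assumed layout: one storage-state JSON per seeded email under .auth/states/
  const context = await browser.newContext({
    storageState: `.auth/states/${user.email}.json`,
  });
  const page = await context.newPage();
  await page.goto("/library"); // resolves against the configured baseURL
  await expect(page).toHaveURL(/library/);
  await context.close();
});
```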
### Test users

- **User pool (basic users)** — created automatically by Playwright global setup. Used by `getTestUser()`
- **Seeded E2E accounts** — created by backend fixtures and logged in during Playwright global setup. Used by `getTestUser()` and `E2E_AUTH_STATES`
- **Rich user with library agents** — created by `backend/test/e2e_test_data.py`. Used by `getTestUserWithLibraryAgents()`

### Current Playwright E2E suite

The CI suite is intentionally limited to the cross-page journeys that still require a real browser. Playwright discovers the PR-gating specs by the `*-happy-path.spec.ts` naming pattern inside `src/playwright/`:

- `src/playwright/auth-happy-path.spec.ts`
- `src/playwright/settings-happy-path.spec.ts`
- `src/playwright/api-keys-happy-path.spec.ts`
- `src/playwright/builder-happy-path.spec.ts`
- `src/playwright/library-happy-path.spec.ts`
- `src/playwright/marketplace-happy-path.spec.ts`
- `src/playwright/publish-happy-path.spec.ts`
- `src/playwright/copilot-happy-path.spec.ts`

### Resetting the DB

If you reset the Docker DB and logins start failing:

1. Delete `frontend/.auth/user-pool.json`
1. Delete `frontend/.auth/states/*` and `frontend/.auth/user-pool.json` if it exists
2. Re-run `poetry run python test/e2e_test_data.py`

## Storybook

@@ -13,11 +13,13 @@
"lint": "next lint && prettier --check .",
"format": "next lint --fix; prettier --write .",
"types": "tsc --noEmit",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:ui",
"test:unit": "vitest run --coverage",
"test:unit:watch": "vitest",
"test:no-build": "playwright test",
"test:e2e": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test:e2e:no-build": "playwright test",
"test:e2e:ui": "playwright test --ui",
"gentests": "playwright codegen http://localhost:3000",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

@@ -7,10 +7,22 @@ import { defineConfig, devices } from "@playwright/test";
import dotenv from "dotenv";
import fs from "fs";
import path from "path";
import { buildCookieConsentStorageState } from "./src/playwright/credentials/storage-state";
dotenv.config({ path: path.resolve(__dirname, ".env") });
dotenv.config({ path: path.resolve(__dirname, "../backend/.env") });

const frontendRoot = __dirname.replaceAll("\\", "/");
const configuredBaseURL =
process.env.PLAYWRIGHT_BASE_URL ?? "http://localhost:3000";
const parsedBaseURL = new URL(configuredBaseURL);
const baseURL = parsedBaseURL.toString().replace(/\/$/, "");
const baseOrigin = parsedBaseURL.origin;
const jsonReporterOutputFile = process.env.PLAYWRIGHT_JSON_OUTPUT_FILE;
const configuredWorkers = process.env.PLAYWRIGHT_WORKERS
? Number(process.env.PLAYWRIGHT_WORKERS)
: process.env.CI
? 8
: undefined;
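A short illustration of why the trailing-slash strip above is needed (behavior of the standard WHATWG `URL` class, shown here outside the config):

```ts
// URL#toString() re-appends "/" to an origin-only URL; stripping it keeps
// later URL joins against baseURL predictable.
const parsed = new URL("http://localhost:3000");
console.log(parsed.toString());                    // "http://localhost:3000/"
console.log(parsed.toString().replace(/\/$/, "")); // "http://localhost:3000"
console.log(parsed.origin);                        // "http://localhost:3000"
```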

// Directory where CI copies .next/static from the Docker container
const staticCoverageDir = path.resolve(__dirname, ".next-static-coverage");
@@ -57,17 +69,18 @@ function resolveSourceMap(sourcePath: string) {
}

export default defineConfig({
testDir: "./src/tests",
testDir: "./src/playwright",
testMatch: /.*-happy-path\.spec\.ts/,
/* Global setup file that runs before all tests */
globalSetup: "./src/tests/global-setup.ts",
globalSetup: "./src/playwright/global-setup.ts",
/* Run tests in files in parallel */
fullyParallel: true,
/* Fail the build on CI if you accidentally left test.only in the source code. */
forbidOnly: !!process.env.CI,
/* Retry on CI only */
retries: process.env.CI ? 1 : 0,
/* use more workers on CI. */
workers: process.env.CI ? 4 : undefined,
retries: process.env.CI ? Number(process.env.PLAYWRIGHT_RETRIES ?? 2) : 0,
/* Higher worker count keeps PR smoke runtime down without sharing page state. */
workers: configuredWorkers,
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
reporter: [
["list"],
@@ -92,40 +105,25 @@ export default defineConfig({
},
},
],
...(jsonReporterOutputFile
? [["json", { outputFile: jsonReporterOutputFile }] as const]
: []),
],
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
use: {
/* Base URL to use in actions like `await page.goto('/')`. */
baseURL: "http://localhost:3000/",
baseURL,

/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
screenshot: "only-on-failure",
bypassCSP: true,

/* Helps debugging failures */
trace: "retain-on-failure",
video: "retain-on-failure",
trace: process.env.CI ? "on-first-retry" : "retain-on-failure",
video: process.env.CI ? "off" : "retain-on-failure",

/* Auto-accept cookies in all tests to prevent banner interference */
storageState: {
cookies: [],
origins: [
{
origin: "http://localhost:3000",
localStorage: [
{
name: "autogpt_cookie_consent",
value: JSON.stringify({
hasConsented: true,
timestamp: Date.now(),
analytics: true,
monitoring: true,
}),
},
],
},
],
},
storageState: buildCookieConsentStorageState(baseOrigin),
},
/* Maximum time one test can run for */
timeout: 25000,
@@ -133,7 +131,7 @@ export default defineConfig({
/* Configure web server to start automatically (local dev only) */
webServer: {
command: "pnpm start",
url: "http://localhost:3000",
url: baseURL,
reuseExistingServer: true,
},


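For reference, here is a minimal sketch of what the extracted `buildCookieConsentStorageState()` helper plausibly returns; the object shape is taken verbatim from the inline `storageState` block this change removes, with only the origin parameterized (the actual module may differ):

```ts
// Sketch of src/playwright/credentials/storage-state.ts, inferred from the
// inline object it replaces. Pre-seeds cookie consent so the banner never
// interferes with tests, regardless of which origin the suite targets.
export function buildCookieConsentStorageState(origin: string) {
  return {
    cookies: [],
    origins: [
      {
        origin,
        localStorage: [
          {
            name: "autogpt_cookie_consent",
            value: JSON.stringify({
              hasConsented: true,
              timestamp: Date.now(),
              analytics: true,
              monitoring: true,
            }),
          },
        ],
      },
    ],
  };
}
```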
@@ -3,6 +3,7 @@ import {
screen,
cleanup,
waitFor,
fireEvent,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { PlatformCostContent } from "../components/PlatformCostContent";
@@ -29,6 +30,16 @@ const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
total_input_tokens: 0,
total_output_tokens: 0,
avg_input_tokens_per_request: 0,
avg_output_tokens_per_request: 0,
avg_cost_microdollars_per_request: 0,
cost_p50_microdollars: 0,
cost_p75_microdollars: 0,
cost_p95_microdollars: 0,
cost_p99_microdollars: 0,
cost_buckets: [],
by_provider: [],
by_user: [],
};
@@ -47,6 +58,20 @@ const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
total_input_tokens: 150000,
total_output_tokens: 60000,
avg_input_tokens_per_request: 2500,
avg_output_tokens_per_request: 1000,
avg_cost_microdollars_per_request: 83333,
cost_p50_microdollars: 50000,
cost_p75_microdollars: 100000,
cost_p95_microdollars: 250000,
cost_p99_microdollars: 500000,
cost_buckets: [
{ bucket: "$0-0.50", count: 80 },
{ bucket: "$0.50-1", count: 15 },
{ bucket: "$1-2", count: 5 },
],
by_provider: [
{
provider: "openai",
@@ -75,6 +100,7 @@ const dashboardWithData: PlatformCostDashboard = {
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
cost_bearing_request_count: 40,
},
],
};
@@ -134,9 +160,14 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
// Known Cost and Estimated Total cards render $0.0000
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
// Typical/Upper/High/Peak Cost) show $0.0000
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(2);
expect(zeroCostItems.length).toBe(7);
expect(screen.getByText("No cost data yet")).toBeDefined();
});

@@ -155,7 +186,9 @@ describe("PlatformCostContent", () => {
);
expect(screen.getByText("$5.0000")).toBeDefined();
expect(screen.getByText("100")).toBeDefined();
expect(screen.getByText("5")).toBeDefined();
// "5" appears in multiple places (Active Users card + bucket count),
// so verify at least one element renders it.
expect(screen.getAllByText("5").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("openai")).toBeDefined();
expect(screen.getByText("google_maps")).toBeDefined();
});
@@ -223,10 +256,83 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Original 4 cards
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
// New average/token cards
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
expect(screen.getByText("Total Tokens")).toBeDefined();
// Percentile cards (friendlier labels)
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
expect(screen.getByText("High Cost (P95)")).toBeDefined();
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
});

it("renders cost distribution buckets", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
expect(screen.getByText("$0-0.50")).toBeDefined();
expect(screen.getByText("$0.50-1")).toBeDefined();
expect(screen.getByText("$1-2")).toBeDefined();
expect(screen.getByText("80")).toBeDefined();
expect(screen.getByText("15")).toBeDefined();
});

it("renders new summary card values from fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Avg Input Tokens: 2500 formatted
expect(screen.getByText("2,500")).toBeDefined();
// Avg Output Tokens: 1000 formatted
expect(screen.getByText("1,000")).toBeDefined();
// P50 cost: 50000 microdollars = $0.0500
expect(screen.getByText("$0.0500")).toBeDefined();
});
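The fixture expectations above pin down the formatting helpers tightly: 5_000_000 microdollars renders as "$5.0000", 50000 as "$0.0500", and 2500 tokens as "2,500". A sketch consistent with those outputs (the real implementations live in `../helpers` and may differ):

```ts
// Microdollars are millionths of a dollar, so divide by 1e6 and fix 4 decimals.
export function formatMicrodollars(microdollars: number): string {
  return `$${(microdollars / 1_000_000).toFixed(4)}`;
}

// Token counts use locale grouping, matching toLocaleString() ("2,500").
export function formatTokens(tokens: number): string {
  return Math.round(tokens).toLocaleString();
}
```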

it("renders user table avg cost column with fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// User table should show Avg Cost / Req header
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
// Input/Output token columns
expect(screen.getByText("Input Tokens")).toBeDefined();
expect(screen.getByText("Output Tokens")).toBeDefined();
});

it("renders filter inputs", async () => {
@@ -246,6 +352,95 @@ describe("PlatformCostContent", () => {
expect(screen.getByText("Apply")).toBeDefined();
});

it("renders execution ID filter input", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Execution ID")).toBeDefined();
expect(screen.getByPlaceholderText("Filter by execution")).toBeDefined();
});

it("pre-fills execution ID filter from searchParams", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ graph_exec_id: "exec-123" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
expect(input.value).toBe("exec-123");
});

it("clears execution ID input on Clear click", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent({ graph_exec_id: "exec-123" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
fireEvent.click(screen.getByText("Clear"));
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
expect(input.value).toBe("");
});

it("passes execution ID to filter on Apply click", async () => {
mockUseGetDashboard.mockReturnValue({
data: emptyDashboard,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
const input = screen.getByPlaceholderText(
"Filter by execution",
) as HTMLInputElement;
fireEvent.change(input, { target: { value: "exec-abc" } });
expect(input.value).toBe("exec-abc");
fireEvent.click(screen.getByText("Apply"));
// After apply, the input still holds the typed value
expect(input.value).toBe("exec-abc");
});

it("copies execution ID to clipboard on cell click in logs tab", async () => {
const writeText = vi.fn().mockResolvedValue(undefined);
vi.stubGlobal("navigator", { ...navigator, clipboard: { writeText } });
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "logs" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// The exec ID cell shows first 8 chars of "gx-123"
const execIdCell = screen.getByText("gx-123".slice(0, 8));
fireEvent.click(execIdCell);
expect(writeText).toHaveBeenCalledWith("gx-123");
vi.unstubAllGlobals();
});

it("renders by-user tab when specified", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,

@@ -67,7 +67,10 @@ function LogsTable({
Cost
</th>
<th scope="col" className="px-3 py-3 text-right">
Tokens
In / Out
</th>
<th scope="col" className="px-3 py-3 text-right">
Cache (R/W)
</th>
<th scope="col" className="px-3 py-3 text-right">
Duration
@@ -105,12 +108,34 @@ function LogsTable({
? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.cache_read_tokens || log.cache_creation_tokens
? `${formatTokens(Number(log.cache_read_tokens ?? 0))} / ${formatTokens(Number(log.cache_creation_tokens ?? 0))}`
: "-"}
</td>
<td className="px-3 py-2 text-right text-xs">
{log.duration != null
? formatDuration(Number(log.duration))
: "-"}
</td>
<td className="px-3 py-2 text-xs text-muted-foreground">
<td
className={[
"px-3 py-2 text-xs text-muted-foreground",
log.graph_exec_id ? "cursor-pointer" : "",
].join(" ")}
title={
log.graph_exec_id ? String(log.graph_exec_id) : undefined
}
onClick={
log.graph_exec_id
? () => {
navigator.clipboard
.writeText(String(log.graph_exec_id))
.catch(() => {});
}
: undefined
}
>
{log.graph_exec_id
? String(log.graph_exec_id).slice(0, 8)
: "-"}
@@ -120,7 +145,7 @@ function LogsTable({
{logs.length === 0 && (
<tr>
<td
colSpan={10}
colSpan={11}
className="px-4 py-8 text-center text-muted-foreground"
>
No logs found

@@ -2,12 +2,13 @@

import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { formatMicrodollars, formatTokens } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";
import type { CostBucket } from "@/app/api/__generated__/models/costBucket";

interface Props {
searchParams: {
@@ -18,6 +19,7 @@ interface Props {
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
};
@@ -46,6 +48,8 @@ export function PlatformCostContent({ searchParams }: Props) {
setBlockInput,
typeInput,
setTypeInput,
executionIDInput,
setExecutionIDInput,
rateOverrides,
handleRateOverride,
updateUrl,
@@ -54,6 +58,76 @@ export function PlatformCostContent({ searchParams }: Props) {
handleExport,
} = usePlatformCostContent(searchParams);

const summaryCards: { label: string; value: string; subtitle?: string }[] =
dashboard
? [
{
label: "Known Cost",
value: formatMicrodollars(dashboard.total_cost_microdollars),
subtitle: "From providers that report USD cost",
},
{
label: "Estimated Total",
value: formatMicrodollars(totalEstimatedCost),
subtitle: "Including per-run cost estimates",
},
{
label: "Total Requests",
value: dashboard.total_requests.toLocaleString(),
},
{
label: "Active Users",
value: dashboard.total_users.toLocaleString(),
},
{
label: "Avg Cost / Request",
value: formatMicrodollars(
dashboard.avg_cost_microdollars_per_request ?? 0,
),
subtitle: "Known cost divided by cost-bearing requests",
},
{
label: "Avg Input Tokens",
value: Math.round(
dashboard.avg_input_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Prompt tokens per request (context size)",
},
{
label: "Avg Output Tokens",
value: Math.round(
dashboard.avg_output_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Completion tokens per request (response length)",
},
{
label: "Total Tokens",
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
subtitle: "Prompt vs completion token split",
},
{
label: "Typical Cost (P50)",
value: formatMicrodollars(dashboard.cost_p50_microdollars ?? 0),
subtitle: "Median cost per request",
},
{
label: "Upper Cost (P75)",
value: formatMicrodollars(dashboard.cost_p75_microdollars ?? 0),
subtitle: "75th percentile cost",
},
{
label: "High Cost (P95)",
value: formatMicrodollars(dashboard.cost_p95_microdollars ?? 0),
subtitle: "95th percentile cost",
},
{
label: "Peak Cost (P99)",
value: formatMicrodollars(dashboard.cost_p99_microdollars ?? 0),
subtitle: "99th percentile cost",
},
]
: [];

return (
<div className="flex flex-col gap-6">
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
@@ -164,6 +238,22 @@ export function PlatformCostContent({ searchParams }: Props) {
onChange={(e) => setTypeInput(e.target.value)}
/>
</div>
<div className="flex flex-col gap-1">
<label
htmlFor="execution-id-filter"
className="text-sm text-muted-foreground"
>
Execution ID
</label>
<input
id="execution-id-filter"
type="text"
placeholder="Filter by execution"
className="rounded border px-3 py-1.5 text-sm"
value={executionIDInput}
onChange={(e) => setExecutionIDInput(e.target.value)}
/>
</div>
<button
onClick={handleFilter}
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
@@ -179,6 +269,7 @@ export function PlatformCostContent({ searchParams }: Props) {
setModelInput("");
setBlockInput("");
setTypeInput("");
setExecutionIDInput("");
updateUrl({
start: "",
end: "",
@@ -187,6 +278,7 @@ export function PlatformCostContent({ searchParams }: Props) {
model: "",
block_name: "",
tracking_type: "",
graph_exec_id: "",
page: "1",
});
}}
@@ -204,37 +296,54 @@ export function PlatformCostContent({ searchParams }: Props) {

{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{/* 12 skeleton placeholders — one per summary card */}
{Array.from({ length: 12 }, (_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
<Skeleton className="h-32 rounded-lg" />
<Skeleton className="h-8 w-48 rounded" />
<Skeleton className="h-64 rounded-lg" />
</div>
) : (
<>
{dashboard && (
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
</div>
<>
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{summaryCards.map((card) => (
<SummaryCard
key={card.label}
label={card.label}
value={card.value}
subtitle={card.subtitle}
/>
))}
</div>

{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
<div className="rounded-lg border p-4">
<h3 className="mb-3 text-sm font-medium">
Cost Distribution by Bucket
</h3>
<div className="grid grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-6">
{dashboard.cost_buckets.map((b: CostBucket) => (
<div
key={b.bucket}
className="flex flex-col items-center rounded border p-2 text-center"
>
<span className="text-xs text-muted-foreground">
{b.bucket}
</span>
<span className="text-lg font-semibold">
{b.count.toLocaleString()}
</span>
</div>
))}
</div>
</div>
)}
</>
)}

<div

@@ -3,6 +3,7 @@ import {
defaultRateFor,
estimateCostForRow,
formatMicrodollars,
formatTokens,
rateKey,
rateUnitLabel,
trackingValue,
@@ -33,6 +34,20 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<th scope="col" className="px-4 py-3 text-right">
Usage
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Input Tokens
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
@@ -74,6 +89,16 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<TrackingBadge trackingType={row.tracking_type} />
</td>
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
<td className="px-4 py-3 text-right">
{row.total_input_tokens > 0
? formatTokens(row.total_input_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.total_output_tokens > 0
? formatTokens(row.total_output_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
@@ -124,7 +149,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
{data.length === 0 && (
<tr>
<td
colSpan={8}
colSpan={10}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

@@ -26,6 +26,9 @@ function UserTable({ data }: Props) {
<th scope="col" className="px-4 py-3 text-right">
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Avg Cost / Req
</th>
</tr>
</thead>
<tbody>
@@ -54,12 +57,21 @@ function UserTable({ data }: Props) {
<td className="px-4 py-3 text-right">
{formatTokens(row.total_output_tokens)}
</td>
<td className="px-4 py-3 text-right">
{(row.cost_bearing_request_count ?? 0) > 0 &&
row.total_cost_microdollars > 0
? formatMicrodollars(
row.total_cost_microdollars /
(row.cost_bearing_request_count ?? 1),
)
: "-"}
</td>
</tr>
))}
{data.length === 0 && (
<tr>
<td
colSpan={5}
colSpan={6}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

@@ -23,6 +23,7 @@ interface InitialSearchParams {
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
}
@@ -43,6 +44,8 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
urlParams.get("block_name") || searchParams.block_name || "";
const typeFilter =
urlParams.get("tracking_type") || searchParams.tracking_type || "";
const executionIDFilter =
urlParams.get("graph_exec_id") || searchParams.graph_exec_id || "";

const [startInput, setStartInput] = useState(toLocalInput(startDate));
const [endInput, setEndInput] = useState(toLocalInput(endDate));
@@ -51,6 +54,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
const [modelInput, setModelInput] = useState(modelFilter);
const [blockInput, setBlockInput] = useState(blockFilter);
const [typeInput, setTypeInput] = useState(typeFilter);
const [executionIDInput, setExecutionIDInput] = useState(executionIDFilter);
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
{},
);
@@ -67,6 +71,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
model: modelFilter || undefined,
block_name: blockFilter || undefined,
tracking_type: typeFilter || undefined,
graph_exec_id: executionIDFilter || undefined,
};

const {
@@ -115,6 +120,7 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
model: modelInput,
block_name: blockInput,
tracking_type: typeInput,
graph_exec_id: executionIDInput,
page: "1",
});
}
@@ -185,6 +191,8 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
setBlockInput,
typeInput,
setTypeInput,
executionIDInput,
setExecutionIDInput,
rateOverrides,
handleRateOverride,
updateUrl,

@@ -7,6 +7,10 @@ type SearchParams = {
end?: string;
provider?: string;
user_id?: string;
model?: string;
block_name?: string;
tracking_type?: string;
graph_exec_id?: string;
page?: string;
tab?: string;
};

@@ -113,8 +113,8 @@ export function CopilotPage() {
// Rate limit reset
rateLimitMessage,
dismissRateLimit,
// Dry run dev toggle
isDryRun,
// Dry run session state
sessionDryRun,
} = useCopilotPage();

const {
@@ -176,10 +176,15 @@ export function CopilotPage() {
>
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<NotificationBanner />
{isDryRun && (
{/* Test mode banner: only shown when the CURRENT session is confirmed to be
a dry_run session via its immutable metadata. Never shown based on the
global isDryRun store preference alone — that only predicts future sessions
and would mislead users browsing non-dry-run sessions while the toggle is on.
The DryRunToggleButton (visible on new chats) already communicates the preference. */}
{sessionId && sessionDryRun && (
<div className="flex items-center justify-center gap-1.5 bg-amber-50 px-3 py-1.5 text-xs font-medium text-amber-800">
<Flask size={13} weight="bold" />
Test mode — new sessions use dry_run=true
Test mode — this session runs agents as simulation
</div>
)}
{/* Drop overlay */}

@@ -0,0 +1,168 @@
import { render, screen, cleanup } from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { CopilotPage } from "../CopilotPage";

// Mock child components that are complex and not under test here
vi.mock("../components/ChatContainer/ChatContainer", () => ({
ChatContainer: () => <div data-testid="chat-container" />,
}));
vi.mock("../components/ChatSidebar/ChatSidebar", () => ({
ChatSidebar: () => <div data-testid="chat-sidebar" />,
}));
vi.mock("../components/DeleteChatDialog/DeleteChatDialog", () => ({
DeleteChatDialog: () => null,
}));
vi.mock("../components/MobileDrawer/MobileDrawer", () => ({
MobileDrawer: () => null,
}));
vi.mock("../components/MobileHeader/MobileHeader", () => ({
MobileHeader: () => null,
}));
vi.mock("../components/NotificationBanner/NotificationBanner", () => ({
NotificationBanner: () => null,
}));
vi.mock("../components/NotificationDialog/NotificationDialog", () => ({
NotificationDialog: () => null,
}));
vi.mock("../components/RateLimitResetDialog/RateLimitResetDialog", () => ({
RateLimitResetDialog: () => null,
}));
vi.mock("../components/ScaleLoader/ScaleLoader", () => ({
ScaleLoader: () => <div data-testid="scale-loader" />,
}));
vi.mock("../components/ArtifactPanel/ArtifactPanel", () => ({
ArtifactPanel: () => null,
}));
vi.mock("@/components/ui/sidebar", () => ({
SidebarProvider: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));

// Mock hooks that hit the network
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
useGetV2GetCopilotUsage: () => ({
data: undefined,
isSuccess: false,
isError: false,
}),
}));
vi.mock("@/hooks/useCredits", () => ({
default: () => ({ credits: null, fetchCredits: vi.fn() }),
}));
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: {
ENABLE_PLATFORM_PAYMENT: "ENABLE_PLATFORM_PAYMENT",
ARTIFACTS: "ARTIFACTS",
CHAT_MODE_OPTION: "CHAT_MODE_OPTION",
},
useGetFlag: () => false,
}));

// Build the base mock return value for useCopilotPage
const basePageState = {
sessionId: null as string | null,
messages: [],
status: "ready" as const,
error: undefined,
stop: vi.fn(),
isReconnecting: false,
isSyncing: false,
createSession: vi.fn(),
onSend: vi.fn(),
isLoadingSession: false,
isSessionError: false,
isCreatingSession: false,
isUploadingFiles: false,
isUserLoading: false,
isLoggedIn: true,
hasMoreMessages: false,
isLoadingMore: false,
loadMore: vi.fn(),
isMobile: false,
isDrawerOpen: false,
sessions: [],
isLoadingSessions: false,
handleOpenDrawer: vi.fn(),
handleCloseDrawer: vi.fn(),
handleDrawerOpenChange: vi.fn(),
handleSelectSession: vi.fn(),
handleNewChat: vi.fn(),
sessionToDelete: null,
isDeleting: false,
handleConfirmDelete: vi.fn(),
handleCancelDelete: vi.fn(),
historicalDurations: {},
rateLimitMessage: null,
dismissRateLimit: vi.fn(),
isDryRun: false,
sessionDryRun: false,
};

const mockUseCopilotPage = vi.fn(() => basePageState);

vi.mock("../useCopilotPage", () => ({
useCopilotPage: () => mockUseCopilotPage(),
}));

afterEach(() => {
cleanup();
mockUseCopilotPage.mockReset();
mockUseCopilotPage.mockImplementation(() => basePageState);
});

describe("CopilotPage test-mode banner", () => {
it("does not show test-mode banner when there is no active session", () => {
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});

it("does not show test-mode banner when session exists but sessionDryRun is false", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: "session-abc",
sessionDryRun: false,
});
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});

it("shows test-mode banner when session exists and sessionDryRun is true", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: "session-abc",
sessionDryRun: true,
});
render(<CopilotPage />);
expect(
screen.getByText(/test mode.*this session runs agents/i),
).toBeDefined();
});

it("does not show test-mode banner when sessionDryRun is true but no sessionId", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
sessionId: null,
sessionDryRun: true,
});
render(<CopilotPage />);
expect(
screen.queryByText(/test mode.*this session runs agents/i),
).toBeNull();
});

it("shows loading spinner when user is loading", () => {
mockUseCopilotPage.mockReturnValue({
...basePageState,
isUserLoading: true,
isLoggedIn: false,
});
render(<CopilotPage />);
expect(screen.getByTestId("scale-loader")).toBeDefined();
expect(screen.queryByTestId("chat-container")).toBeNull();
});
});
@@ -1,6 +1,11 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { IMPERSONATION_HEADER_NAME } from "@/lib/constants";
import { getCopilotAuthHeaders } from "../helpers";
import {
getCopilotAuthHeaders,
getSendSuppressionReason,
resolveSessionDryRun,
} from "../helpers";
import type { UIMessage } from "ai";

vi.mock("@/lib/supabase/actions", () => ({
getWebSocketToken: vi.fn(),
@@ -16,6 +21,42 @@ import { getSystemHeaders } from "@/lib/impersonation";
const mockGetWebSocketToken = vi.mocked(getWebSocketToken);
const mockGetSystemHeaders = vi.mocked(getSystemHeaders);

describe("resolveSessionDryRun", () => {
it("returns false when queryData is null", () => {
expect(resolveSessionDryRun(null)).toBe(false);
});

it("returns false when queryData is undefined", () => {
expect(resolveSessionDryRun(undefined)).toBe(false);
});

it("returns false when status is not 200", () => {
expect(resolveSessionDryRun({ status: 404 })).toBe(false);
});

it("returns false when status is 200 but metadata.dry_run is false", () => {
expect(
resolveSessionDryRun({
status: 200,
data: { metadata: { dry_run: false } },
}),
).toBe(false);
});

it("returns false when status is 200 but metadata is missing", () => {
expect(resolveSessionDryRun({ status: 200, data: {} })).toBe(false);
});

it("returns true when status is 200 and metadata.dry_run is true", () => {
expect(
resolveSessionDryRun({
status: 200,
data: { metadata: { dry_run: true } },
}),
).toBe(true);
});
});

describe("getCopilotAuthHeaders", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
@@ -72,3 +113,71 @@ describe("getCopilotAuthHeaders", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getSendSuppressionReason ─────────────────────────────────────────────────
|
||||
|
||||
function makeUserMsg(text: string): UIMessage {
|
||||
return {
|
||||
id: "msg-1",
|
||||
role: "user",
|
||||
content: text,
|
||||
parts: [{ type: "text", text }],
|
||||
} as UIMessage;
|
||||
}
|
||||
|
||||
describe("getSendSuppressionReason", () => {
|
||||
it("returns null when no dedup context exists (fresh ref)", () => {
|
||||
const result = getSendSuppressionReason({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: null,
|
||||
messages: [],
|
||||
});
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => {
|
||||
const result = getSendSuppressionReason({
|
||||
text: "hello",
|
||||
isReconnectScheduled: true,
|
||||
lastSubmittedText: null,
|
||||
messages: [],
|
||||
});
|
||||
expect(result).toBe("reconnecting");
|
||||
});
|
||||
|
||||
it("returns 'duplicate' when same text was submitted and is the last user message", () => {
|
||||
// This is the core regression test: after a successful turn the ref
|
||||
// is intentionally NOT cleared to null, so submitting the same text
|
||||
// again is caught here.
|
||||
const result = getSendSuppressionReason({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages: [makeUserMsg("hello")],
|
||||
});
|
||||
expect(result).toBe("duplicate");
|
||||
});
|
||||
|
||||
it("returns null when same ref text but different last user message (different question)", () => {
|
||||
// User asked "hello" before, got a reply, then asked a different question
|
||||
// — the last user message in chat is now different, so no suppression.
|
||||
const result = getSendSuppressionReason({
|
||||
text: "hello",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "hello",
|
||||
messages: [makeUserMsg("hello"), makeUserMsg("something else")],
|
||||
});
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null when text differs from lastSubmittedText", () => {
|
||||
const result = getSendSuppressionReason({
|
||||
text: "new question",
|
||||
isReconnectScheduled: false,
|
||||
lastSubmittedText: "old question",
|
||||
messages: [makeUserMsg("old question")],
|
||||
});
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { describe, expect, it, beforeEach, vi } from "vitest";
|
||||
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
|
||||
import { useCopilotUIStore } from "../store";
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
@@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => {
|
||||
isNotificationsEnabled: false,
|
||||
isSoundEnabled: true,
|
||||
showNotificationDialog: false,
|
||||
copilotMode: "extended_thinking",
|
||||
copilotChatMode: "extended_thinking",
|
||||
copilotLlmModel: "standard",
|
||||
});
|
||||
});
|
||||
|
||||
@@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("copilotMode", () => {
|
||||
describe("copilotChatMode", () => {
|
||||
it("defaults to extended_thinking", () => {
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe(
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
|
||||
"extended_thinking",
|
||||
);
|
||||
});
|
||||
|
||||
it("sets mode to fast", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast");
|
||||
});
|
||||
|
||||
it("sets mode back to extended_thinking", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotMode("extended_thinking");
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe(
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("extended_thinking");
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
|
||||
"extended_thinking",
|
||||
);
|
||||
});
|
||||
|
||||
it("does not persist mode to localStorage", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
|
||||
it("persists mode to localStorage", () => {
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBe("fast");
|
||||
});
|
||||
});
|
||||
|
||||
describe("copilotLlmModel", () => {
|
||||
it("defaults to standard", () => {
|
||||
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("standard");
|
||||
});
|
||||
|
||||
it("sets model to advanced", () => {
|
||||
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
|
||||
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("advanced");
|
||||
});
|
||||
|
||||
it("persists model to localStorage", () => {
|
||||
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
|
||||
expect(window.localStorage.getItem("copilot-model")).toBe("advanced");
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearCopilotLocalData", () => {
|
||||
it("resets state and clears localStorage keys", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
|
||||
useCopilotUIStore.getState().setNotificationsEnabled(true);
|
||||
useCopilotUIStore.getState().toggleSound();
|
||||
useCopilotUIStore.getState().addCompletedSession("s1");
|
||||
@@ -190,7 +208,8 @@ describe("useCopilotUIStore", () => {
|
||||
useCopilotUIStore.getState().clearCopilotLocalData();
|
||||
|
||||
const state = useCopilotUIStore.getState();
|
||||
expect(state.copilotMode).toBe("extended_thinking");
|
||||
expect(state.copilotChatMode).toBe("extended_thinking");
|
||||
expect(state.copilotLlmModel).toBe("standard");
|
||||
expect(state.isNotificationsEnabled).toBe(false);
|
||||
expect(state.isSoundEnabled).toBe(true);
|
||||
expect(state.completedSessionIDs.size).toBe(0);
|
||||
@@ -198,6 +217,8 @@ describe("useCopilotUIStore", () => {
|
||||
window.localStorage.getItem("copilot-notifications-enabled"),
|
||||
).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-model")).toBeNull();
|
||||
expect(
|
||||
window.localStorage.getItem("copilot-completed-sessions"),
|
||||
).toBeNull();
|
||||
@@ -222,3 +243,24 @@ describe("useCopilotUIStore", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("useCopilotUIStore localStorage initialisation", () => {
|
||||
afterEach(() => {
|
||||
vi.resetModules();
|
||||
window.localStorage.clear();
|
||||
});
|
||||
|
||||
it("reads fast chat mode from localStorage on store creation", async () => {
|
||||
window.localStorage.setItem("copilot-mode", "fast");
    vi.resetModules();
    const { useCopilotUIStore: fresh } = await import("../store");
    expect(fresh.getState().copilotChatMode).toBe("fast");
  });

  it("reads advanced model from localStorage on store creation", async () => {
    window.localStorage.setItem("copilot-model", "advanced");
    vi.resetModules();
    const { useCopilotUIStore: fresh } = await import("../store");
    expect(fresh.getState().copilotLlmModel).toBe("advanced");
  });
});

@@ -0,0 +1,145 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { ArtifactCard } from "./ArtifactCard";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";

function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
  return {
    id: "file-001",
    title: "report.html",
    mimeType: "text/html",
    sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
    origin: "agent",
    ...overrides,
  };
}

const meta: Meta<typeof ArtifactCard> = {
  title: "Copilot/ArtifactCard",
  component: ArtifactCard,
  tags: ["autodocs"],
  parameters: {
    layout: "padded",
    docs: {
      description: {
        component:
          "Inline artifact card rendered in chat messages. Openable artifacts show a caret and open the ArtifactPanel on click. Download-only artifacts trigger a file download.",
      },
    },
  },
  decorators: [
    (Story) => (
      <div className="w-96">
        <Story />
      </div>
    ),
  ],
};

export default meta;
type Story = StoryObj<typeof meta>;

export const OpenableHTML: Story = {
  name: "Openable (HTML)",
  args: {
    artifact: makeArtifact({
      title: "dashboard.html",
      mimeType: "text/html",
    }),
  },
};

export const OpenableImage: Story = {
  name: "Openable (Image)",
  args: {
    artifact: makeArtifact({
      id: "img-card",
      title: "chart.png",
      mimeType: "image/png",
    }),
  },
};

export const OpenableCode: Story = {
  name: "Openable (Code)",
  args: {
    artifact: makeArtifact({
      title: "script.py",
      mimeType: "text/x-python",
    }),
  },
};

export const DownloadOnly: Story = {
  name: "Download Only (ZIP)",
  args: {
    artifact: makeArtifact({
      title: "archive.zip",
      mimeType: "application/zip",
      sizeBytes: 2_500_000,
    }),
  },
};

export const PreviewableVideo: Story = {
  name: "Previewable (Video)",
  args: {
    artifact: makeArtifact({
      title: "demo.mp4",
      mimeType: "video/mp4",
      sizeBytes: 15_000_000,
    }),
  },
  parameters: {
    docs: {
      description: {
        story:
          "Videos with supported formats (MP4, WebM, M4V) are previewable inline in the artifact panel.",
      },
    },
  },
};

export const WithSize: Story = {
  name: "With File Size",
  args: {
    artifact: makeArtifact({
      title: "data.csv",
      mimeType: "text/csv",
      sizeBytes: 524_288,
    }),
  },
};

export const UserUpload: Story = {
  name: "User Upload Origin",
  args: {
    artifact: makeArtifact({
      title: "requirements.txt",
      mimeType: "text/plain",
      origin: "user-upload",
    }),
  },
};

export const ActiveState: Story = {
  name: "Active (Panel Open)",
  args: {
    artifact: makeArtifact({ id: "active-card" }),
  },
  decorators: [
    (Story) => {
      useCopilotUIStore.setState({
        artifactPanel: {
          isOpen: true,
          isMinimized: false,
          isMaximized: false,
          width: 600,
          activeArtifact: makeArtifact({ id: "active-card" }),
          history: [],
        },
      });
      return <Story />;
    },
  ],
};
@@ -0,0 +1,223 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactPanel } from "./ArtifactPanel";
import { useCopilotUIStore } from "../../store";
import type { ArtifactRef } from "../../store";

const PROXY_BASE = "/api/proxy/api/workspace/files";

function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
  return {
    id: "file-001",
    title: "report.html",
    mimeType: "text/html",
    sourceUrl: `${PROXY_BASE}/file-001/download`,
    origin: "agent",
    ...overrides,
  };
}

function openPanelWith(artifact: ArtifactRef) {
  useCopilotUIStore.setState({
    artifactPanel: {
      isOpen: true,
      isMinimized: false,
      isMaximized: false,
      width: 600,
      activeArtifact: artifact,
      history: [],
    },
  });
}

const meta: Meta<typeof ArtifactPanel> = {
  title: "Copilot/ArtifactPanel",
  component: ArtifactPanel,
  tags: ["autodocs"],
  parameters: {
    layout: "fullscreen",
    docs: {
      description: {
        component:
          "Side panel for previewing workspace artifacts. Supports resize, minimize, maximize, and navigation history. Bug: panel auto-opens on chat switch instead of staying collapsed.",
      },
    },
  },
  decorators: [
    (Story) => (
      <div className="flex h-[600px] w-full">
        <div className="flex-1 bg-zinc-50 p-8">
          <p className="text-sm text-zinc-500">Chat area</p>
        </div>
        <Story />
      </div>
    ),
  ],
};

export default meta;
type Story = StoryObj<typeof meta>;

export const OpenWithTextArtifact: Story = {
  name: "Open — Text File",
  decorators: [
    (Story) => {
      openPanelWith(
        makeArtifact({ title: "notes.txt", mimeType: "text/plain" }),
      );
      return <Story />;
    },
  ],
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/file-001/download`, () => {
          return HttpResponse.text(
            "These are some notes from the agent execution.\n\nKey findings:\n1. Performance improved by 23%\n2. Memory usage reduced\n3. Error rate dropped to 0.1%",
          );
        }),
      ],
    },
  },
};

export const OpenWithHTMLArtifact: Story = {
  name: "Open — HTML",
  decorators: [
    (Story) => {
      openPanelWith(
        makeArtifact({
          id: "html-panel",
          title: "dashboard.html",
          mimeType: "text/html",
          sourceUrl: `${PROXY_BASE}/html-panel/download`,
        }),
      );
      return <Story />;
    },
  ],
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/html-panel/download`, () => {
          return HttpResponse.text(
            `<!DOCTYPE html><html><body class="p-8 font-sans"><h1 class="text-2xl font-bold text-indigo-600">Dashboard</h1><p class="mt-2 text-gray-600">HTML artifact in the panel.</p></body></html>`,
          );
        }),
      ],
    },
  },
};

export const OpenWithImageArtifact: Story = {
  name: "Open — Image (Bug: No Loading State)",
  decorators: [
    (Story) => {
      openPanelWith(
        makeArtifact({
          id: "img-panel",
          title: "chart.png",
          mimeType: "image/png",
          sourceUrl: `${PROXY_BASE}/img-panel/download`,
        }),
      );
      return <Story />;
    },
  ],
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/img-panel/download`, () => {
          return HttpResponse.text(
            '<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300"><rect width="500" height="300" fill="#dbeafe"/><text x="250" y="150" text-anchor="middle" fill="#1e40af" font-size="20">Image Preview (no skeleton)</text></svg>',
            { headers: { "Content-Type": "image/svg+xml" } },
          );
        }),
      ],
    },
    docs: {
      description: {
        story:
          "**BUG:** Image artifacts render with a bare `<img>` tag — no loading skeleton or error handling. Compare with text/HTML artifacts which show a proper skeleton while loading.",
      },
    },
  },
};

export const MinimizedStrip: Story = {
  name: "Minimized",
  decorators: [
    (Story) => {
      useCopilotUIStore.setState({
        artifactPanel: {
          isOpen: true,
          isMinimized: true,
          isMaximized: false,
          width: 600,
          activeArtifact: makeArtifact(),
          history: [],
        },
      });
      return <Story />;
    },
  ],
};

export const ErrorState: Story = {
  name: "Error — Failed to Load (Stale Artifact)",
  decorators: [
    (Story) => {
      openPanelWith(
        makeArtifact({
          id: "stale-panel",
          title: "old-report.html",
          mimeType: "text/html",
          sourceUrl: `${PROXY_BASE}/stale-panel/download`,
        }),
      );
      return <Story />;
    },
  ],
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/stale-panel/download`, () => {
          return new HttpResponse(null, { status: 404 });
        }),
      ],
    },
    docs: {
      description: {
        story:
          "Shows what users see when opening a previously generated artifact that no longer exists on the backend (404). The 'Try again' button retries the fetch.",
      },
    },
  },
};

export const Closed: Story = {
  name: "Closed (Default State)",
  decorators: [
    (Story) => {
      useCopilotUIStore.setState({
        artifactPanel: {
          isOpen: false,
          isMinimized: false,
          isMaximized: false,
          width: 600,
          activeArtifact: null,
          history: [],
        },
      });
      return <Story />;
    },
  ],
  parameters: {
    docs: {
      description: {
        story:
          "The default state — panel is closed. It should only open when a user clicks on an artifact card in the chat.",
      },
    },
  },
};
@@ -0,0 +1,413 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { downloadArtifact } from "../downloadArtifact";
import type { ArtifactRef } from "../../../store";

function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
  return {
    id: "file-001",
    title: "report.pdf",
    mimeType: "application/pdf",
    sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
    origin: "agent",
    ...overrides,
  };
}

describe("downloadArtifact", () => {
  let clickSpy: ReturnType<typeof vi.fn>;
  let removeSpy: ReturnType<typeof vi.fn>;

  beforeEach(() => {
    clickSpy = vi.fn();
    removeSpy = vi.fn();

    vi.stubGlobal(
      "URL",
      Object.assign(URL, {
        createObjectURL: vi.fn().mockReturnValue("blob:fake-url"),
        revokeObjectURL: vi.fn(),
      }),
    );

    vi.spyOn(document, "createElement").mockReturnValue({
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    } as unknown as HTMLAnchorElement);

    vi.spyOn(document.body, "appendChild").mockImplementation(
      (node) => node as ChildNode,
    );
  });

  afterEach(() => {
    vi.restoreAllMocks();
    vi.unstubAllGlobals();
  });

  it("downloads file successfully on 200 response", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["pdf content"])),
      }),
    );

    await downloadArtifact(makeArtifact());

    expect(fetch).toHaveBeenCalledWith(
      "/api/proxy/api/workspace/files/file-001/download",
    );
    expect(clickSpy).toHaveBeenCalled();
    expect(removeSpy).toHaveBeenCalled();
    expect(URL.revokeObjectURL).toHaveBeenCalledWith("blob:fake-url");
  });
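
  // Retry contract, as exercised by the tests below: network errors and
  // 408/429/500 responses are retried up to 3 total attempts, while 4xx
  // client errors such as 403/404 fail immediately with no retry.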

  it("rejects on persistent server error after exhausting retries", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: false,
        status: 500,
      }),
    );

    await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
      "Download failed: 500",
    );
    expect(clickSpy).not.toHaveBeenCalled();
  });

  it("rejects on persistent network error after exhausting retries", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        return Promise.reject(new Error("Network error"));
      }),
    );

    await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
      "Network error",
    );
    expect(callCount).toBe(3);
    expect(clickSpy).not.toHaveBeenCalled();
  });

  it("retries on transient network error and succeeds", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        if (callCount === 1) {
          return Promise.reject(new Error("Connection reset"));
        }
        return Promise.resolve({
          ok: true,
          blob: () => Promise.resolve(new Blob(["content"])),
        });
      }),
    );

    await downloadArtifact(makeArtifact());
    expect(callCount).toBe(2);
    expect(clickSpy).toHaveBeenCalled();
  });

  it("retries on transient 500 and succeeds", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        if (callCount === 1) {
          return Promise.resolve({ ok: false, status: 500 });
        }
        return Promise.resolve({
          ok: true,
          blob: () => Promise.resolve(new Blob(["content"])),
        });
      }),
    );

    // Should succeed on second attempt
    await downloadArtifact(makeArtifact());
    expect(callCount).toBe(2);
    expect(clickSpy).toHaveBeenCalled();
  });

  it("sanitizes dangerous filenames", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(makeArtifact({ title: "../../../etc/passwd" }));

    expect(anchor.download).not.toContain("..");
    expect(anchor.download).not.toContain("/");
  });

  // ── Transient retry codes ─────────────────────────────────────────

  it("retries on 408 (Request Timeout) and succeeds", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        if (callCount === 1) {
          return Promise.resolve({ ok: false, status: 408 });
        }
        return Promise.resolve({
          ok: true,
          blob: () => Promise.resolve(new Blob(["content"])),
        });
      }),
    );

    await downloadArtifact(makeArtifact());
    expect(callCount).toBe(2);
    expect(clickSpy).toHaveBeenCalled();
  });

  it("retries on 429 (Too Many Requests) and succeeds", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        if (callCount === 1) {
          return Promise.resolve({ ok: false, status: 429 });
        }
        return Promise.resolve({
          ok: true,
          blob: () => Promise.resolve(new Blob(["content"])),
        });
      }),
    );

    await downloadArtifact(makeArtifact());
    expect(callCount).toBe(2);
    expect(clickSpy).toHaveBeenCalled();
  });

  // ── Non-transient errors ──────────────────────────────────────────

  it("rejects immediately on 403 (non-transient) without retry", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        return Promise.resolve({ ok: false, status: 403 });
      }),
    );

    await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
      "Download failed: 403",
    );
    expect(callCount).toBe(1);
    expect(clickSpy).not.toHaveBeenCalled();
  });

  it("rejects immediately on 404 without retry", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        return Promise.resolve({ ok: false, status: 404 });
      }),
    );

    await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
      "Download failed: 404",
    );
    expect(callCount).toBe(1);
  });

  // ── Exhausted retries ─────────────────────────────────────────────

  it("rejects after exhausting all retries on persistent 500", async () => {
    let callCount = 0;
    vi.stubGlobal(
      "fetch",
      vi.fn().mockImplementation(() => {
        callCount++;
        return Promise.resolve({ ok: false, status: 500 });
      }),
    );

    await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
      "Download failed: 500",
    );
    // Initial attempt + 2 retries = 3 total
    expect(callCount).toBe(3);
    expect(clickSpy).not.toHaveBeenCalled();
  });

  // ── Filename edge cases ───────────────────────────────────────────

  it("falls back to 'download' when title is empty", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(makeArtifact({ title: "" }));
    expect(anchor.download).toBe("download");
  });

  it("falls back to 'download' when title is only dots", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    // Dot-only names should not produce a hidden or empty filename.
    await downloadArtifact(makeArtifact({ title: "...." }));
    expect(anchor.download).toBe("download");
  });

  it("replaces special chars with underscores (not empty)", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(makeArtifact({ title: '***???"' }));
    // Special chars become underscores, not removed
    expect(anchor.download).toBe("_______");
  });

  it("strips leading dots from filename", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(makeArtifact({ title: "...hidden.txt" }));
    expect(anchor.download).not.toMatch(/^\./);
    expect(anchor.download).toContain("hidden.txt");
  });

  it("replaces Windows-reserved characters", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(
      makeArtifact({ title: "file<name>with:bad*chars?.txt" }),
    );
    expect(anchor.download).not.toMatch(/[<>:*?]/);
  });

  it("replaces control characters in filename", async () => {
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );

    const anchor = {
      href: "",
      download: "",
      click: clickSpy,
      remove: removeSpy,
    };
    vi.spyOn(document, "createElement").mockReturnValue(
      anchor as unknown as HTMLAnchorElement,
    );

    await downloadArtifact(
      makeArtifact({ title: "file\x00with\x1fcontrol.txt" }),
    );
    expect(anchor.download).not.toMatch(/[\x00-\x1f]/);
  });
});
@@ -0,0 +1,460 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactContent } from "./ArtifactContent";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import {
  Code,
  File,
  FileHtml,
  FileText,
  Image,
  Table,
} from "@phosphor-icons/react";

const PROXY_BASE = "/api/proxy/api/workspace/files";

function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
  return {
    id: "file-001",
    title: "test.txt",
    mimeType: "text/plain",
    sourceUrl: `${PROXY_BASE}/file-001/download`,
    origin: "agent",
    ...overrides,
  };
}

function makeClassification(
  overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
  return {
    type: "text",
    icon: FileText,
    label: "Text",
    openable: true,
    hasSourceToggle: false,
    ...overrides,
  };
}

const meta: Meta<typeof ArtifactContent> = {
  title: "Copilot/ArtifactContent",
  component: ArtifactContent,
  tags: ["autodocs"],
  parameters: {
    layout: "padded",
    docs: {
      description: {
        component:
          "Renders artifact content based on file type classification. Supports images, HTML, code, CSV, JSON, markdown, PDF, and plain text. Bug: image artifacts render as bare <img> with no loading/error states.",
      },
    },
  },
  decorators: [
    (Story) => (
      <div
        className="flex h-[500px] w-[600px] flex-col overflow-hidden border border-zinc-200"
        style={{ resize: "both" }}
      >
        <Story />
      </div>
    ),
  ],
};

export default meta;
type Story = StoryObj<typeof meta>;

export const ImageArtifactPNG: Story = {
  name: "Image (PNG) — No Loading Skeleton (Bug #1)",
  args: {
    artifact: makeArtifact({
      id: "img-png",
      title: "chart.png",
      mimeType: "image/png",
      sourceUrl: `${PROXY_BASE}/img-png/download`,
    }),
    isSourceView: false,
    classification: makeClassification({ type: "image", icon: Image }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/img-png/download`, () => {
          return HttpResponse.text(
            '<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#e0e7ff"/><text x="200" y="150" text-anchor="middle" fill="#4338ca" font-size="24">PNG Placeholder</text></svg>',
            { headers: { "Content-Type": "image/svg+xml" } },
          );
        }),
      ],
    },
    docs: {
      description: {
        story:
          "**BUG:** This renders a bare `<img>` tag with no loading skeleton or error handling. Compare with WorkspaceFileRenderer which has proper Skeleton + onError states.",
      },
    },
  },
};

export const ImageArtifactSVG: Story = {
  name: "Image (SVG)",
  args: {
    artifact: makeArtifact({
      id: "img-svg",
      title: "diagram.svg",
      mimeType: "image/svg+xml",
      sourceUrl: `${PROXY_BASE}/img-svg/download`,
    }),
    isSourceView: false,
    classification: makeClassification({ type: "image", icon: Image }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/img-svg/download`, () => {
          return HttpResponse.text(
            '<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#fef3c7"/><circle cx="200" cy="150" r="80" fill="#f59e0b"/><text x="200" y="155" text-anchor="middle" fill="white" font-size="20">SVG OK</text></svg>',
            { headers: { "Content-Type": "image/svg+xml" } },
          );
        }),
      ],
    },
  },
};

export const HTMLArtifact: Story = {
  name: "HTML",
  args: {
    artifact: makeArtifact({
      id: "html-001",
      title: "page.html",
      mimeType: "text/html",
      sourceUrl: `${PROXY_BASE}/html-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "html",
      icon: FileHtml,
      label: "HTML",
      hasSourceToggle: true,
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/html-001/download`, () => {
          return HttpResponse.text(
            `<!DOCTYPE html>
<html>
<head><title>Artifact Preview</title></head>
<body class="p-8 font-sans">
<h1 class="text-2xl font-bold text-indigo-600 mb-4">HTML Artifact</h1>
<p class="text-gray-700">This is an HTML artifact rendered in a sandboxed iframe with Tailwind CSS injected.</p>
<div class="mt-4 p-4 bg-blue-50 rounded-lg border border-blue-200">
<p class="text-blue-800">Interactive content works via allow-scripts sandbox.</p>
</div>
</body>
</html>`,
            { headers: { "Content-Type": "text/html" } },
          );
        }),
      ],
    },
  },
};

export const CodeArtifact: Story = {
  name: "Code (Python)",
  args: {
    artifact: makeArtifact({
      id: "code-001",
      title: "analysis.py",
      mimeType: "text/x-python",
      sourceUrl: `${PROXY_BASE}/code-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "code",
      icon: Code,
      label: "Code",
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/code-001/download`, () => {
          return HttpResponse.text(
            `import pandas as pd
import matplotlib.pyplot as plt

def analyze_data(filepath: str) -> pd.DataFrame:
    """Load and analyze CSV data."""
    df = pd.read_csv(filepath)
    summary = df.describe()
    print(f"Loaded {len(df)} rows")
    return summary

if __name__ == "__main__":
    result = analyze_data("data.csv")
    print(result)`,
            { headers: { "Content-Type": "text/plain" } },
          );
        }),
      ],
    },
  },
};

export const CSVArtifact: Story = {
  name: "CSV (Spreadsheet)",
  args: {
    artifact: makeArtifact({
      id: "csv-001",
      title: "data.csv",
      mimeType: "text/csv",
      sourceUrl: `${PROXY_BASE}/csv-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "csv",
      icon: Table,
      label: "Spreadsheet",
      hasSourceToggle: true,
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/csv-001/download`, () => {
          return HttpResponse.text(
            `Name,Age,City,Score
Alice,28,New York,92
Bob,35,San Francisco,87
Charlie,22,Chicago,95
Diana,31,Boston,88
Eve,27,Seattle,91`,
            { headers: { "Content-Type": "text/csv" } },
          );
        }),
      ],
    },
  },
};

export const JSONArtifact: Story = {
  name: "JSON (Data)",
  args: {
    artifact: makeArtifact({
      id: "json-001",
      title: "config.json",
      mimeType: "application/json",
      sourceUrl: `${PROXY_BASE}/json-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "json",
      icon: Code,
      label: "Data",
      hasSourceToggle: true,
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/json-001/download`, () => {
          return HttpResponse.text(
            JSON.stringify(
              {
                name: "AutoGPT Agent",
                version: "2.0",
                capabilities: ["web_search", "code_execution", "file_io"],
                settings: { maxTokens: 4096, temperature: 0.7 },
              },
              null,
              2,
            ),
            { headers: { "Content-Type": "application/json" } },
          );
        }),
      ],
    },
  },
};

export const MarkdownArtifact: Story = {
  name: "Markdown",
  args: {
    artifact: makeArtifact({
      id: "md-001",
      title: "README.md",
      mimeType: "text/markdown",
      sourceUrl: `${PROXY_BASE}/md-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "markdown",
      icon: FileText,
      label: "Document",
      hasSourceToggle: true,
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/md-001/download`, () => {
          return HttpResponse.text(
            `# Project Summary

## Overview
This is a **markdown** artifact rendered through the global renderer registry.

## Features
- Headings and paragraphs
- **Bold** and *italic* text
- Lists and code blocks

\`\`\`python
print("Hello from markdown!")
\`\`\`

> Blockquotes are also supported.`,
            { headers: { "Content-Type": "text/plain" } },
          );
        }),
      ],
    },
  },
};

export const PDFArtifact: Story = {
  name: "PDF",
  args: {
    artifact: makeArtifact({
      id: "pdf-001",
      title: "report.pdf",
      mimeType: "application/pdf",
      sourceUrl: `${PROXY_BASE}/pdf-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "pdf",
      icon: FileText,
      label: "PDF",
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/pdf-001/download`, () => {
          return HttpResponse.arrayBuffer(new ArrayBuffer(100), {
            headers: { "Content-Type": "application/pdf" },
          });
        }),
      ],
    },
    docs: {
      description: {
        story:
          "PDF artifacts are rendered in an unsandboxed iframe using a blob URL (Chromium bug #413851 prevents sandboxed PDF rendering).",
      },
    },
  },
};

export const ErrorState: Story = {
  name: "Error — Failed to Load Content",
  args: {
    artifact: makeArtifact({
      id: "error-001",
      title: "old-report.html",
      mimeType: "text/html",
      sourceUrl: `${PROXY_BASE}/error-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "html",
      icon: FileHtml,
      label: "HTML",
      hasSourceToggle: true,
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/error-001/download`, () => {
          return new HttpResponse(null, { status: 404 });
        }),
      ],
    },
    docs: {
      description: {
        story:
          "Shows the error state when an artifact fails to load (e.g., old/expired file returning 404). Includes a 'Try again' retry button.",
      },
    },
  },
};

export const LoadingSkeleton: Story = {
  name: "Loading State",
  args: {
    artifact: makeArtifact({
      id: "loading-001",
      title: "loading.html",
      mimeType: "text/html",
      sourceUrl: `${PROXY_BASE}/loading-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "html",
      icon: FileHtml,
      label: "HTML",
    }),
  },
  parameters: {
    msw: {
      handlers: [
        http.get(`${PROXY_BASE}/loading-001/download`, async () => {
          // Delay response to show loading state
          await new Promise((r) => setTimeout(r, 999999));
          return HttpResponse.text("never resolves");
        }),
      ],
    },
    docs: {
      description: {
        story:
          "Shows the skeleton loading state while content is being fetched.",
      },
    },
  },
};

export const DownloadOnly: Story = {
  name: "Download Only (Binary)",
  args: {
    artifact: makeArtifact({
      id: "bin-001",
      title: "archive.zip",
      mimeType: "application/zip",
      sourceUrl: `${PROXY_BASE}/bin-001/download`,
    }),
    isSourceView: false,
    classification: makeClassification({
      type: "download-only",
      icon: File,
      label: "File",
      openable: false,
    }),
  },
  parameters: {
    docs: {
      description: {
        story:
          "Download-only files (binary, video, etc.) are not rendered inline. The ArtifactPanel shows nothing for these — they are handled by ArtifactCard with a download button.",
      },
    },
  },
};
@@ -2,7 +2,8 @@

import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { Suspense } from "react";
import { Suspense, useState } from "react";
import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
@@ -63,6 +64,90 @@ function ArtifactContentLoader({
  );
}

function ArtifactImage({ src, alt }: { src: string; alt: string }) {
  const [loaded, setLoaded] = useState(false);
  const [error, setError] = useState(false);

  if (error) {
    return (
      <div
        role="alert"
        className="flex flex-col items-center justify-center gap-3 p-8 text-center"
      >
        <p className="text-sm text-zinc-500">Failed to load image</p>
        <button
          type="button"
          onClick={() => {
            setError(false);
            setLoaded(false);
          }}
          className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
        >
          Try again
        </button>
      </div>
    );
  }

  return (
    <div className="relative flex items-center justify-center p-4">
      {!loaded && (
        <Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
      )}
      {/* eslint-disable-next-line @next/next/no-img-element */}
      <img
        src={src}
        alt={alt}
        className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
        onLoad={() => setLoaded(true)}
        onError={() => setError(true)}
      />
    </div>
  );
}

function ArtifactVideo({ src }: { src: string }) {
  const [loaded, setLoaded] = useState(false);
  const [error, setError] = useState(false);

  if (error) {
    return (
      <div
        role="alert"
        className="flex flex-col items-center justify-center gap-3 p-8 text-center"
      >
        <p className="text-sm text-zinc-500">Failed to load video</p>
        <button
          type="button"
          onClick={() => {
            setError(false);
            setLoaded(false);
          }}
          className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
        >
          Try again
        </button>
      </div>
    );
  }

  return (
    <div className="relative flex items-center justify-center p-4">
      {!loaded && (
        <Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
      )}
      <video
        src={src}
        controls
        preload="metadata"
        className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
        onLoadedMetadata={() => setLoaded(true)}
        onError={() => setError(true)}
      />
    </div>
  );
}

function ArtifactRenderer({
  artifact,
  content,
@@ -79,17 +164,19 @@ function ArtifactRenderer({
  // Image: render directly from URL (no content fetch)
  if (classification.type === "image") {
    return (
      <div className="flex items-center justify-center p-4">
        {/* eslint-disable-next-line @next/next/no-img-element */}
        <img
          src={artifact.sourceUrl}
          alt={artifact.title}
          className="max-h-full max-w-full object-contain"
        />
      </div>
      <ArtifactImage
        key={artifact.sourceUrl}
        src={artifact.sourceUrl}
        alt={artifact.title}
      />
    );
  }

  // Video: render with <video> controls (no content fetch)
  if (classification.type === "video") {
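    // React's `key` forces a remount whenever the artifact changes, so the
    // player's internal loaded/error state resets instead of leaking across artifacts.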
    return <ArtifactVideo key={artifact.sourceUrl} src={artifact.sourceUrl} />;
  }

  if (classification.type === "pdf" && pdfUrl) {
    // No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
    // (Chromium bug #413851). The blob URL has a null origin so it can't
@@ -164,7 +251,16 @@ function ArtifactRenderer({

  // CSV: pass with explicit metadata so CSVRenderer matches
  if (classification.type === "csv") {
    const csvMeta = { mimeType: "text/csv", filename: artifact.title };
    const normalizedMime = artifact.mimeType
      ?.toLowerCase()
      .split(";")[0]
      ?.trim();
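    // Prefer TSV when the normalized MIME type or a .tsv extension says so;
    // everything else falls back to plain CSV.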
    const csvMimeType =
      normalizedMime === "text/tab-separated-values" ||
      artifact.title.toLowerCase().endsWith(".tsv")
        ? "text/tab-separated-values"
        : "text/csv";
    const csvMeta = { mimeType: csvMimeType, filename: artifact.title };
    const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
    if (csvRenderer) {
      return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;

@@ -0,0 +1,67 @@
import { render, screen, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import {
  buildReactArtifactSrcDoc,
  collectPreviewStyles,
  transpileReactArtifactSource,
} from "./reactArtifactPreview";

vi.mock("./reactArtifactPreview", () => ({
  buildReactArtifactSrcDoc: vi.fn(),
  collectPreviewStyles: vi.fn(),
  transpileReactArtifactSource: vi.fn(),
}));

describe("ArtifactReactPreview", () => {
  beforeEach(() => {
    vi.mocked(collectPreviewStyles).mockReturnValue("<style>preview</style>");
    vi.mocked(buildReactArtifactSrcDoc).mockReturnValue("<html>preview</html>");
  });

  it("renders an iframe preview after transpilation succeeds", async () => {
    vi.mocked(transpileReactArtifactSource).mockResolvedValue(
      "module.exports.default = function Artifact() { return null; };",
    );

    const { container } = render(
      <ArtifactReactPreview
        source="export default function Artifact() { return null; }"
        title="Artifact.tsx"
      />,
    );

    await waitFor(() => {
      expect(buildReactArtifactSrcDoc).toHaveBeenCalledWith(
        "module.exports.default = function Artifact() { return null; };",
        "Artifact.tsx",
        "<style>preview</style>",
      );
    });

    const iframe = container.querySelector("iframe");
    expect(iframe).toBeTruthy();
    expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
    expect(iframe?.getAttribute("title")).toBe("Artifact.tsx preview");
    expect(iframe?.getAttribute("srcdoc")).toBe("<html>preview</html>");
  });

  it("shows a readable error when transpilation fails", async () => {
    vi.mocked(transpileReactArtifactSource).mockRejectedValue(
      new Error("Transpile exploded"),
    );

    render(
      <ArtifactReactPreview
        source="export default function Artifact() {"
        title="Broken.tsx"
      />,
    );

    await waitFor(() => {
      expect(screen.getByText("Failed to render React preview")).toBeTruthy();
    });

    expect(screen.getByText("Transpile exploded")).toBeTruthy();
  });
});
@@ -0,0 +1,970 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import {
  cleanup,
  fireEvent,
  render,
  screen,
  waitFor,
} from "@testing-library/react";
import { ArtifactContent } from "../ArtifactContent";
import type { ArtifactRef } from "../../../../store";
import { classifyArtifact, type ArtifactClassification } from "../../helpers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { ArtifactReactPreview } from "../ArtifactReactPreview";

// Mock the renderers so we don't pull in the full renderer dependency tree
vi.mock("@/components/contextual/OutputRenderers", () => ({
  globalRegistry: {
    getRenderer: vi.fn().mockReturnValue({
      render: vi.fn((_val: unknown, meta: Record<string, unknown>) => (
        <div data-testid="global-renderer">
          rendered:{String(meta?.mimeType ?? "unknown")}
        </div>
      )),
    }),
  },
}));

vi.mock(
  "@/components/contextual/OutputRenderers/renderers/CodeRenderer",
  () => ({
    codeRenderer: {
      render: vi.fn((content: string) => (
        <div data-testid="code-renderer">{content}</div>
      )),
    },
  }),
);

vi.mock("../ArtifactReactPreview", () => ({
  ArtifactReactPreview: vi.fn(
    ({ source, title }: { source: string; title: string }) => (
      <div data-testid="react-preview" data-title={title}>
        {source}
      </div>
    ),
  ),
}));

function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
  return {
    id: "file-001",
    title: "test.txt",
    mimeType: "text/plain",
    sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
    origin: "agent",
    ...overrides,
  };
}

function makeClassification(
  overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
  return {
    type: "text",
    icon: vi.fn(() => null) as unknown as ArtifactClassification["icon"],
    label: "Text",
    openable: true,
    hasSourceToggle: false,
    ...overrides,
  };
}

describe("ArtifactContent", () => {
  beforeEach(() => {
    vi.clearAllMocks();
    vi.stubGlobal(
      "fetch",
      vi.fn().mockResolvedValue({
        ok: true,
        text: () => Promise.resolve("file content here"),
        blob: () => Promise.resolve(new Blob(["content"])),
      }),
    );
  });

  afterEach(() => {
    cleanup();
    vi.unstubAllGlobals();
  });

  // ── Image ─────────────────────────────────────────────────────────

  it("renders image artifact as img tag with loading skeleton", () => {
    const artifact = makeArtifact({
      id: "img-001",
      title: "photo.png",
      mimeType: "image/png",
      sourceUrl: "/api/proxy/api/workspace/files/img-001/download",
    });
    const classification = makeClassification({ type: "image" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const img = container.querySelector("img");
    expect(img).toBeTruthy();
    expect(img?.getAttribute("src")).toBe(
      "/api/proxy/api/workspace/files/img-001/download",
    );
    expect(fetch).not.toHaveBeenCalled();
  });

  it("image artifact shows loading skeleton before image loads", () => {
    const artifact = makeArtifact({
      id: "img-skeleton",
      title: "photo.png",
      mimeType: "image/png",
      sourceUrl: "/api/proxy/api/workspace/files/img-skeleton/download",
    });
    const classification = makeClassification({ type: "image" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    // Skeleton uses animate-pulse class
    const skeleton = container.querySelector('[class*="animate-pulse"]');
    expect(skeleton).toBeTruthy();
  });

  it("image artifact shows error state when image fails to load", () => {
    const artifact = makeArtifact({
      id: "img-error",
      title: "broken.png",
      mimeType: "image/png",
      sourceUrl: "/api/proxy/api/workspace/files/img-error/download",
    });
    const classification = makeClassification({ type: "image" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const img = container.querySelector("img");
    expect(img).toBeTruthy();
    fireEvent.error(img!);

    const errorAlert = screen.queryByRole("alert");
    expect(errorAlert).toBeTruthy();
    expect(screen.queryByText("Failed to load image")).toBeTruthy();
  });

  it("image retry resets error and re-shows img", async () => {
    const artifact = makeArtifact({
      id: "img-retry",
      title: "retry.png",
      mimeType: "image/png",
      sourceUrl: "/api/proxy/api/workspace/files/img-retry/download",
    });
    const classification = makeClassification({ type: "image" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const img = container.querySelector("img");
    fireEvent.error(img!);

    // Should show error state
    await waitFor(() => {
      expect(screen.queryByText("Failed to load image")).toBeTruthy();
    });

    // Click "Try again"
    fireEvent.click(screen.getByRole("button", { name: /try again/i }));

    // Error should be cleared, img should reappear
    await waitFor(() => {
      expect(screen.queryByText("Failed to load image")).toBeNull();
      expect(container.querySelector("img")).toBeTruthy();
    });
  });

  // ── Video ─────────────────────────────────────────────────────────

  it("renders video artifact with video tag and controls", () => {
    const artifact = makeArtifact({
      id: "vid-001",
      title: "clip.mp4",
      mimeType: "video/mp4",
      sourceUrl: "/api/proxy/api/workspace/files/vid-001/download",
    });
    const classification = makeClassification({ type: "video" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const video = container.querySelector("video");
    expect(video).toBeTruthy();
    expect(video?.hasAttribute("controls")).toBe(true);
    expect(video?.getAttribute("src")).toBe(
      "/api/proxy/api/workspace/files/vid-001/download",
    );
    expect(fetch).not.toHaveBeenCalled();
  });

  it("video shows loading skeleton before metadata loads", () => {
    const artifact = makeArtifact({
      id: "vid-skel",
      title: "clip.mp4",
      mimeType: "video/mp4",
      sourceUrl: "/api/proxy/api/workspace/files/vid-skel/download",
    });
    const classification = makeClassification({ type: "video" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const skeleton = container.querySelector('[class*="animate-pulse"]');
    expect(skeleton).toBeTruthy();

    // After metadata loads, skeleton should disappear
    const video = container.querySelector("video");
    fireEvent.loadedMetadata(video!);

    expect(container.querySelector('[class*="animate-pulse"]')).toBeNull();
  });

  it("video shows error state when video fails to load", () => {
    const artifact = makeArtifact({
      id: "vid-error",
      title: "broken.mp4",
      mimeType: "video/mp4",
      sourceUrl: "/api/proxy/api/workspace/files/vid-error/download",
    });
    const classification = makeClassification({ type: "video" });

    const { container } = render(
      <ArtifactContent
        artifact={artifact}
        isSourceView={false}
        classification={classification}
      />,
    );

    const video = container.querySelector("video");
    expect(video).toBeTruthy();
    fireEvent.error(video!);

    const errorAlert = screen.queryByRole("alert");
    expect(errorAlert).toBeTruthy();
    expect(screen.queryByText("Failed to load video")).toBeTruthy();
  });

  it("video retry resets error and re-shows video", async () => {
    const artifact = makeArtifact({
      id: "vid-retry",
      title: "retry.mp4",
      mimeType: "video/mp4",
      sourceUrl: "/api/proxy/api/workspace/files/vid-retry/download",
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const video = container.querySelector("video");
|
||||
fireEvent.error(video!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load video")).toBeTruthy();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load video")).toBeNull();
|
||||
expect(container.querySelector("video")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ── PDF ───────────────────────────────────────────────────────────
|
||||
|
||||
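  // jsdom does not implement URL.createObjectURL, so the PDF path stubs it
  // below; the assertion only checks that the stubbed blob URL is forwarded
  // to the iframe src (assumption: nothing else consumes URL in this path).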
it("renders PDF artifact in unsandboxed iframe with blob URL", async () => {
|
||||
const blobUrl = "blob:http://localhost/fake-pdf-blob";
|
||||
vi.stubGlobal(
|
||||
"URL",
|
||||
Object.assign(URL, {
|
||||
createObjectURL: vi.fn().mockReturnValue(blobUrl),
|
||||
revokeObjectURL: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "pdf-render",
|
||||
title: "report.pdf",
|
||||
mimeType: "application/pdf",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/pdf-render/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("src")).toBe(blobUrl);
|
||||
// No sandbox attribute — Chrome blocks PDF in sandboxed iframes
|
||||
expect(iframe?.hasAttribute("sandbox")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Fetch error ───────────────────────────────────────────────────
|
||||
|
||||
it("shows error state with retry button on fetch failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "error-content-test" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const errorText = await screen.findByText("Failed to load content");
|
||||
expect(errorText).toBeTruthy();
|
||||
|
||||
const retryButtons = screen.getAllByRole("button", { name: /try again/i });
|
||||
expect(retryButtons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
// ── HTML ──────────────────────────────────────────────────────────
|
||||
|
||||
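  // sandbox="allow-scripts" without allow-same-origin puts the artifact's
  // scripts in an opaque origin, isolated from the parent page's cookies
  // and storage.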
it("renders HTML content in sandboxed iframe", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () =>
|
||||
Promise.resolve("<html><body><h1>Hello World</h1></body></html>"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "html-001",
|
||||
title: "page.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("page.html");
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
});
|
||||
|
||||
// ── Source view ───────────────────────────────────────────────────
|
||||
|
||||
it("renders source view as pre tag", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source code here"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "source-view-test" });
|
||||
const classification = makeClassification({
|
||||
type: "html",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={true}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByText("source code here");
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("source code here");
|
||||
});
|
||||
|
||||
// ── React ─────────────────────────────────────────────────────────
|
||||
|
||||
it("renders react artifacts via ArtifactReactPreview", async () => {
|
||||
const jsxSource = "export default function App() { return <div>Hi</div>; }";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsxSource),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "react-001",
|
||||
title: "App.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = makeClassification({ type: "react" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const preview = await screen.findByTestId("react-preview");
|
||||
expect(preview).toBeTruthy();
|
||||
expect(preview.textContent).toContain(jsxSource);
|
||||
expect(preview.getAttribute("data-title")).toBe("App.tsx");
|
||||
});
|
||||
|
||||
it("routes a concrete props-based TSX artifact into ArtifactReactPreview", async () => {
|
||||
const jsxSource = `
|
||||
import React, { FC, useState } from "react";
|
||||
|
||||
interface ArtifactFile {
|
||||
id: string;
|
||||
name: string;
|
||||
mimeType: string;
|
||||
url: string;
|
||||
sizeBytes: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
files: ArtifactFile[];
|
||||
onSelect: (file: ArtifactFile) => void;
|
||||
}
|
||||
|
||||
export const previewProps: Props = {
|
||||
files: [
|
||||
{
|
||||
id: "1",
|
||||
name: "report.png",
|
||||
mimeType: "image/png",
|
||||
url: "/report.png",
|
||||
sizeBytes: 2048,
|
||||
},
|
||||
],
|
||||
onSelect: () => {},
|
||||
};
|
||||
|
||||
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
|
||||
const handleClick = (file: ArtifactFile) => {
|
||||
setSelected(file.id);
|
||||
onSelect(file);
|
||||
};
|
||||
|
||||
return (
|
||||
<ul>
|
||||
{files.map((file) => (
|
||||
<li key={file.id} onClick={() => handleClick(file)}>
|
||||
<span>{selected === file.id ? "selected" : file.name}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
);
|
||||
};
|
||||
|
||||
export default ArtifactList;
|
||||
`;
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsxSource),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "react-props-001",
|
||||
title: "ArtifactList.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = classifyArtifact(artifact.mimeType, artifact.title);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const preview = await screen.findByTestId("react-preview");
|
||||
expect(preview.textContent).toContain("previewProps");
|
||||
expect(preview.getAttribute("data-title")).toBe("ArtifactList.tsx");
|
||||
expect(vi.mocked(ArtifactReactPreview).mock.calls[0]?.[0]).toEqual(
|
||||
expect.objectContaining({
|
||||
source: expect.stringContaining("export const previewProps"),
|
||||
title: "ArtifactList.tsx",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
// ── Code ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders code artifacts via codeRenderer", async () => {
|
||||
const code = 'def hello():\n print("hi")';
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(code),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "code-render-001",
|
||||
title: "script.py",
|
||||
mimeType: "text/x-python",
|
||||
});
|
||||
const classification = makeClassification({ type: "code" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("code-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain(code);
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
filename: "events.jsonl",
|
||||
mimeType: "application/x-ndjson",
|
||||
content: '{"event":"start"}\n{"event":"finish"}',
|
||||
},
|
||||
{
|
||||
filename: ".env.local",
|
||||
mimeType: "text/plain",
|
||||
content: "OPENAI_API_KEY=test\nDEBUG=true",
|
||||
},
|
||||
{
|
||||
filename: "Dockerfile",
|
||||
mimeType: "text/plain",
|
||||
content: "FROM node:20\nRUN pnpm install",
|
||||
},
|
||||
{
|
||||
filename: "schema.graphql",
|
||||
mimeType: "text/plain",
|
||||
content: "type Query { viewer: User }",
|
||||
},
|
||||
])(
|
||||
"renders concrete code artifact $filename through codeRenderer",
|
||||
async ({ filename, mimeType, content }) => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(content),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: `code-${filename}`,
|
||||
title: filename,
|
||||
mimeType,
|
||||
});
|
||||
const classification = classifyArtifact(
|
||||
artifact.mimeType,
|
||||
artifact.title,
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTestId("code-renderer");
|
||||
|
||||
expect(classification.type).toBe("code");
|
||||
expect(vi.mocked(codeRenderer.render)).toHaveBeenCalledWith(
|
||||
content,
|
||||
expect.objectContaining({
|
||||
filename,
|
||||
mimeType,
|
||||
type: "code",
|
||||
}),
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
// ── JSON ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders valid JSON via globalRegistry", async () => {
|
||||
const jsonContent = JSON.stringify({ key: "value" }, null, 2);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsonContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "json-render-001",
|
||||
title: "data.json",
|
||||
mimeType: "application/json",
|
||||
});
|
||||
const classification = makeClassification({ type: "json" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("application/json");
|
||||
});
|
||||
|
||||
it("renders invalid JSON as fallback pre tag", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
);
|
||||
const originalImpl = vi
|
||||
.mocked(globalRegistry.getRenderer)
|
||||
.getMockImplementation();
|
||||
|
||||
// For invalid JSON, JSON.parse throws, then the registry fallback
|
||||
// also returns null → falls through to <pre>
|
||||
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("{invalid json!!!"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "json-invalid-001",
|
||||
title: "bad.json",
|
||||
mimeType: "application/json",
|
||||
});
|
||||
const classification = makeClassification({ type: "json" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("{invalid json!!!");
|
||||
});
|
||||
|
||||
// Restore
|
||||
if (originalImpl) {
|
||||
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
|
||||
}
|
||||
});
|
||||
|
||||
// ── CSV ───────────────────────────────────────────────────────────
|
||||
|
||||
it("renders CSV via globalRegistry with text/csv metadata", async () => {
|
||||
const csvContent = "Name,Age\nAlice,30\nBob,25";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(csvContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "csv-render-001",
|
||||
title: "data.csv",
|
||||
mimeType: "text/csv",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "csv",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/csv");
|
||||
});
|
||||
|
||||
it("renders TSV via globalRegistry with tab-separated metadata", async () => {
|
||||
const tsvContent = "Name\tAge\nAlice\t30\nBob\t25";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(tsvContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "tsv-render-001",
|
||||
title: "data.tsv",
|
||||
mimeType: "text/tab-separated-values",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "csv",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/tab-separated-values");
|
||||
});
|
||||
|
||||
// ── Markdown ──────────────────────────────────────────────────────
|
||||
|
||||
it("renders markdown via globalRegistry", async () => {
|
||||
const mdContent = "# Hello\n\nThis is **markdown**.";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(mdContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "md-render-001",
|
||||
title: "README.md",
|
||||
mimeType: "text/markdown",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "markdown",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/markdown");
|
||||
});
|
||||
|
||||
// ── Text fallback ─────────────────────────────────────────────────
|
||||
|
||||
it("renders text artifacts via globalRegistry fallback", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("plain text content"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "text-render-001",
|
||||
title: "notes.txt",
|
||||
mimeType: "text/plain",
|
||||
});
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
filename: "calendar.ics",
|
||||
mimeType: "text/calendar",
|
||||
content: "BEGIN:VCALENDAR\nVERSION:2.0\nEND:VCALENDAR",
|
||||
},
|
||||
{
|
||||
filename: "contact.vcf",
|
||||
mimeType: "text/vcard",
|
||||
content: "BEGIN:VCARD\nVERSION:4.0\nFN:Alice Example\nEND:VCARD",
|
||||
},
|
||||
])(
|
||||
"renders concrete text artifact $filename through the global renderer path",
|
||||
async ({ filename, mimeType, content }) => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(content),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: `text-${filename}`,
|
||||
title: filename,
|
||||
mimeType,
|
||||
});
|
||||
const classification = classifyArtifact(
|
||||
artifact.mimeType,
|
||||
artifact.title,
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTestId("global-renderer");
|
||||
|
||||
expect(classification.type).toBe("text");
|
||||
expect(vi.mocked(globalRegistry.getRenderer)).toHaveBeenCalledWith(
|
||||
content,
|
||||
expect.objectContaining({
|
||||
filename,
|
||||
mimeType,
|
||||
}),
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
it("falls back to pre tag when no renderer matches", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
);
|
||||
const originalImpl = vi
|
||||
.mocked(globalRegistry.getRenderer)
|
||||
.getMockImplementation();
|
||||
|
||||
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("raw content fallback"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "fallback-pre-001",
|
||||
title: "unknown.txt",
|
||||
mimeType: "text/plain",
|
||||
});
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("raw content fallback");
|
||||
});
|
||||
|
||||
// Restore
|
||||
if (originalImpl) {
|
||||
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -3,6 +3,7 @@ import { renderHook, waitFor, act } from "@testing-library/react";
|
||||
import {
|
||||
useArtifactContent,
|
||||
getCachedArtifactContent,
|
||||
clearContentCache,
|
||||
} from "../useArtifactContent";
|
||||
import type { ArtifactRef } from "../../../../store";
|
||||
import type { ArtifactClassification } from "../../helpers";
|
||||
@@ -33,6 +34,7 @@ function makeClassification(
|
||||
|
||||
describe("useArtifactContent", () => {
|
||||
beforeEach(() => {
|
||||
clearContentCache();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
@@ -44,6 +46,7 @@ describe("useArtifactContent", () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
clearContentCache();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
@@ -109,9 +112,12 @@ describe("useArtifactContent", () => {
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
});
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.content).toBeNull();
|
||||
@@ -132,6 +138,176 @@ describe("useArtifactContent", () => {
|
||||
expect(getCachedArtifactContent("cache-test")).toBe("file content here");
|
||||
});
|
||||
|
||||
it("sets error on fetch failure for HTML artifacts (stale artifact)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "stale-html-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.content).toBeNull();
|
||||
});
|
||||
|
||||
it("sets error on network failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockRejectedValue(new Error("Network error")),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "network-error-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("Network error");
|
||||
expect(result.current.content).toBeNull();
|
||||
});
|
||||
|
||||
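  // Note: these retry tests appear to run on real timers, so the 2500ms
  // waitFor budget below must cover CONTENT_FETCH_MAX_RETRIES (2) fixed
  // 500ms delays plus the fetches themselves.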
it("retries transient HTML fetch failures before surfacing an error", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount < 3) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 503,
|
||||
headers: {
|
||||
get: () => "application/json",
|
||||
},
|
||||
json: () => Promise.resolve({ detail: "temporary upstream error" }),
|
||||
});
|
||||
}
|
||||
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("<html>ok now</html>"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "transient-html-retry" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.content).toBe("<html>ok now</html>");
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(callCount).toBe(3);
|
||||
expect(result.current.error).toBeNull();
|
||||
});
|
||||
|
||||
it("surfaces backend error detail from JSON responses", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
headers: {
|
||||
get: () => "application/json",
|
||||
},
|
||||
json: () => Promise.resolve({ detail: "File not found" }),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "json-error-detail" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.error).toContain("File not found");
|
||||
});
|
||||
|
||||
it("retry after 404 on HTML artifact clears cache and re-fetches", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
});
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("<html>recovered</html>"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "retry-html-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.retry();
|
||||
});
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.content).toBe("<html>recovered</html>");
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toBeNull();
|
||||
});
|
||||
|
||||
it("retry clears cache and re-fetches", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
@@ -164,4 +340,162 @@ describe("useArtifactContent", () => {
|
||||
expect(result.current.content).toBe("response 2");
|
||||
});
|
||||
});
|
||||
|
||||
// ── Non-transient errors ──────────────────────────────────────────
|
||||
|
||||
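  // 403 is not in the transient set (408/429/5xx), so exactly one fetch
  // call is expected, with no retries.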
it("rejects immediately on 403 without retrying", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 403,
|
||||
text: () => Promise.resolve("Forbidden"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "forbidden-no-retry" });
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(callCount).toBe(1);
|
||||
expect(result.current.error).toContain("403");
|
||||
});
|
||||
|
||||
// ── Video skip-fetch ──────────────────────────────────────────────
|
||||
|
||||
it("skips fetch for video artifacts (like image)", async () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "video-skip",
|
||||
mimeType: "video/mp4",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
expect(result.current.content).toBeNull();
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// ── PDF error paths ───────────────────────────────────────────────
|
||||
|
||||
it("sets error on PDF fetch failure (non-2xx)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
text: () => Promise.resolve("Server Error"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "pdf-error" });
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("500");
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
});
|
||||
|
||||
it("sets error on PDF network failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockRejectedValue(new Error("PDF network failure")),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "pdf-network-error" });
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("PDF network failure");
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
});
|
||||
|
||||
// ── LRU cache eviction ────────────────────────────────────────────
|
||||
|
||||
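  // Assumes eviction follows Map insertion order (oldest key deleted
  // first), so lru-0 is the entry that gets dropped.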
it("evicts oldest entry when cache exceeds 12 items", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation((url: string) => {
|
||||
const fileId = url.match(/files\/([^/]+)\/download/)?.[1] ?? "unknown";
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(`content-${fileId}`),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
// Fill the cache with 12 entries (cache max = 12)
|
||||
for (let i = 0; i < 12; i++) {
|
||||
const artifact = makeArtifact({
|
||||
id: `lru-${i}`,
|
||||
sourceUrl: `/api/proxy/api/workspace/files/lru-${i}/download`,
|
||||
});
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
});
|
||||
}
|
||||
|
||||
// All 12 should be cached
|
||||
expect(getCachedArtifactContent("lru-0")).toBe("content-lru-0");
|
||||
expect(getCachedArtifactContent("lru-11")).toBe("content-lru-11");
|
||||
|
||||
// Adding a 13th should evict lru-0 (the oldest)
|
||||
const artifact13 = makeArtifact({
|
||||
id: "lru-12",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/lru-12/download",
|
||||
});
|
||||
const { result: result13 } = renderHook(() =>
|
||||
useArtifactContent(artifact13, classification),
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(result13.current.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(getCachedArtifactContent("lru-0")).toBeUndefined();
|
||||
expect(getCachedArtifactContent("lru-1")).toBe("content-lru-1");
|
||||
expect(getCachedArtifactContent("lru-12")).toBe("content-lru-12");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -85,4 +85,35 @@ describe("buildReactArtifactSrcDoc", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("box-sizing: border-box");
|
||||
});
|
||||
|
||||
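  // The assertions below string-match the generated srcDoc runtime rather
  // than executing it, which keeps these tests cheap and deterministic.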
it("supports a named previewProps export in the runtime", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("moduleExports.previewProps");
|
||||
expect(doc).toContain("React.createElement(Component, previewProps || {})");
|
||||
});
|
||||
|
||||
it("includes a helpful message for components that expect props", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("This component appears to expect props.");
|
||||
expect(doc).toContain("previewProps");
|
||||
});
|
||||
|
||||
it("checks componentExpectsProps on the raw component before wrapping", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("RawComponent.length > 0");
|
||||
expect(doc).toContain("wrapWithProviders(RawComponent");
|
||||
});
|
||||
|
||||
it("wrapWithProviders forwards props to the wrapped component", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("function WrappedArtifactPreview(props)");
|
||||
expect(doc).toContain("React.createElement(Component, props)");
|
||||
});
|
||||
|
||||
it("supports named exported components and provider wrappers in the runtime", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain('name.endsWith("Provider")');
|
||||
expect(doc).toContain("/^[A-Z]/.test(name)");
|
||||
expect(doc).toContain("wrapWithProviders");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -169,8 +169,8 @@ export function buildReactArtifactSrcDoc(
|
||||
return Component;
|
||||
}
|
||||
|
||||
return function WrappedArtifactPreview() {
|
||||
let tree = React.createElement(Component);
|
||||
return function WrappedArtifactPreview(props) {
|
||||
let tree = React.createElement(Component, props);
|
||||
|
||||
for (let i = providers.length - 1; i >= 0; i -= 1) {
|
||||
tree = React.createElement(providers[i], null, tree);
|
||||
@@ -180,6 +180,17 @@ export function buildReactArtifactSrcDoc(
|
||||
};
|
||||
}
|
||||
|
||||
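// previewProps is an optional named export an artifact author can provide
// so that prop-driven components render with sample data in the preview.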
function getPreviewProps(moduleExports) {
  if (
    moduleExports.previewProps &&
    typeof moduleExports.previewProps === "object"
  ) {
    return moduleExports.previewProps;
  }

  return null;
}

function require(name) {
  if (name === "react") {
    return React;
@@ -235,6 +246,11 @@ export function buildReactArtifactSrcDoc(

  render() {
    if (this.state.error) {
      const propsHelp =
        this.props.componentExpectsProps && !this.props.hasPreviewProps
          ? "\\n\\nThis component appears to expect props. Export a named previewProps object with sample values to render it in artifact preview."
          : "";

      return React.createElement(
        "div",
        {
@@ -249,7 +265,9 @@ export function buildReactArtifactSrcDoc(
            whiteSpace: "pre-wrap",
          },
        },
-        this.state.error.stack || this.state.error.message || String(this.state.error),
+        (this.state.error.stack ||
+          this.state.error.message ||
+          String(this.state.error)) + propsHelp,
      );
    }

@@ -296,16 +314,19 @@ export function buildReactArtifactSrcDoc(
    moduleExports.App = executionResult.app;
  }

-  const Component = wrapWithProviders(
-    getRenderableCandidate(moduleExports),
-    moduleExports,
-  );
+  const RawComponent = getRenderableCandidate(moduleExports);
+  const componentExpectsProps = RawComponent.length > 0;
+  const Component = wrapWithProviders(RawComponent, moduleExports);
+  const previewProps = getPreviewProps(moduleExports);

  ReactDOM.createRoot(rootElement).render(
    React.createElement(
      PreviewErrorBoundary,
-      null,
-      React.createElement(Component),
+      {
+        componentExpectsProps: componentExpectsProps,
+        hasPreviewProps: previewProps != null,
+      },
+      React.createElement(Component, previewProps || {}),
    ),
  );
} catch (error) {

@@ -48,4 +48,104 @@ describe("transpileReactArtifactSource", () => {
    expect(out).not.toContain(": string");
    expect(out).toContain("function greet(name)");
  });

  it("transpiles a concrete props-based artifact with previewProps", async () => {
    const src = `
import React, { FC, useState } from "react";

interface ArtifactFile {
  id: string;
  name: string;
  mimeType: string;
  url: string;
  sizeBytes: number;
}

interface Props {
  files: ArtifactFile[];
  onSelect: (file: ArtifactFile) => void;
}

export const previewProps: Props = {
  files: [
    {
      id: "1",
      name: "report.png",
      mimeType: "image/png",
      url: "/report.png",
      sizeBytes: 2048,
    },
  ],
  onSelect: () => {},
};

const ArtifactList: FC<Props> = ({ files, onSelect }) => {
  const [selected, setSelected] = useState<string | null>(null);

  const handleClick = (file: ArtifactFile) => {
    setSelected(file.id);
    onSelect(file);
  };

  return (
    <ul>
      {files.map((file) => (
        <li key={file.id} onClick={() => handleClick(file)}>
          <span>{selected === file.id ? "selected" : file.name}</span>
        </li>
      ))}
    </ul>
  );
};

export default ArtifactList;
`;

    const out = await transpileReactArtifactSource(src, "ArtifactList.tsx");

    expect(out).toContain("exports.previewProps");
    expect(out).toContain("exports.default = ArtifactList");
    expect(out).toContain("useState");
    expect(out).not.toContain("interface Props");
    expect(out).not.toContain("interface ArtifactFile");
  });

  it("transpiles a named export artifact without a default export", async () => {
    const src = `
export function ResultsGrid() {
  return (
    <section>
      <h1>Results</h1>
      <p>Named export preview</p>
    </section>
  );
}
`;

    const out = await transpileReactArtifactSource(src, "ResultsGrid.tsx");

    expect(out).toContain("exports.ResultsGrid = ResultsGrid");
    expect(out).toMatch(/\.createElement\(/);
    expect(out).not.toContain("<section>");
  });

  it("transpiles a provider-wrapped artifact with separate provider and component exports", async () => {
    const src = `
import React from "react";

export function DemoProvider({ children }: { children: React.ReactNode }) {
  return <div data-theme="demo">{children}</div>;
}

export function DashboardCard() {
  return <main>Provider-backed preview</main>;
}
`;

    const out = await transpileReactArtifactSource(src, "DashboardCard.tsx");

    expect(out).toContain("exports.DemoProvider = DemoProvider");
    expect(out).toContain("exports.DashboardCard = DashboardCard");
    expect(out).not.toContain("React.ReactNode");
  });
});

@@ -7,12 +7,116 @@ import type { ArtifactClassification } from "../helpers";
// Cap on cached text artifacts. Long sessions with many large artifacts
// would otherwise hold every opened one in memory.
const CONTENT_CACHE_MAX = 12;
const CONTENT_FETCH_MAX_RETRIES = 2;
const CONTENT_FETCH_RETRY_DELAY_MS = 500;

// Module-level LRU keyed by artifact id so a sibling action (e.g. Copy
// in ArtifactPanelHeader) can read what the panel already fetched without
// re-hitting the network.
const contentCache = new Map<string, string>();

class ArtifactFetchError extends Error {}

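// 408 (timeout) and 429 (rate limit) plus all 5xx responses are treated as
// retryable; other 4xx client errors are considered permanent.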
function isTransientArtifactFetchStatus(status: number): boolean {
  return status === 408 || status === 429 || status >= 500;
}

function sleep(ms: number): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

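// Best-effort extraction of a human-readable message from FastAPI-style
// { detail } payloads, generic { error } strings, and nested
// { detail: { message } } shapes (the branches below are the contract).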
function getArtifactErrorMessage(body: unknown): string | null {
  if (typeof body === "string") {
    const trimmed = body.replace(/\s+/g, " ").trim();
    return trimmed || null;
  }

  if (!body || typeof body !== "object") return null;

  if (
    "detail" in body &&
    typeof body.detail === "string" &&
    body.detail.trim().length > 0
  ) {
    return body.detail.trim();
  }

  if (
    "error" in body &&
    typeof body.error === "string" &&
    body.error.trim().length > 0
  ) {
    return body.error.trim();
  }

  if (
    "detail" in body &&
    body.detail &&
    typeof body.detail === "object" &&
    "message" in body.detail &&
    typeof body.detail.message === "string" &&
    body.detail.message.trim().length > 0
  ) {
    return body.detail.message.trim();
  }

  return null;
}

async function parseArtifactFetchError(response: Response): Promise<string> {
  const prefix = `Failed to fetch: ${response.status}`;
  const contentType =
    response.headers?.get?.("content-type")?.toLowerCase() ?? "";

  try {
    if (
      contentType.includes("application/json") &&
      typeof response.json === "function"
    ) {
      const body = await response.json();
      const detail = getArtifactErrorMessage(body);
      return detail ? `${prefix} ${detail}` : prefix;
    }

    if (typeof response.text === "function") {
      const text = await response.text();
      const detail = getArtifactErrorMessage(text);
      return detail ? `${prefix} ${detail}` : prefix;
    }
  } catch {
    return prefix;
  }

  return prefix;
}

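// Retries transient failures with a fixed delay between attempts.
// ArtifactFetchError is rethrown as-is so a non-transient HTTP status
// never burns extra attempts.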
async function fetchArtifactResponse(url: string): Promise<Response> {
  for (let attempt = 0; attempt <= CONTENT_FETCH_MAX_RETRIES; attempt++) {
    try {
      const response = await fetch(url);
      if (response.ok) return response;

      if (
        !isTransientArtifactFetchStatus(response.status) ||
        attempt === CONTENT_FETCH_MAX_RETRIES
      ) {
        throw new ArtifactFetchError(await parseArtifactFetchError(response));
      }
    } catch (error) {
      if (error instanceof ArtifactFetchError) throw error;
      if (attempt === CONTENT_FETCH_MAX_RETRIES) {
        throw error instanceof Error
          ? error
          : new Error("Failed to fetch artifact");
      }
    }

    await sleep(CONTENT_FETCH_RETRY_DELAY_MS);
  }

  throw new Error("Failed to fetch artifact");
}

export function getCachedArtifactContent(id: string): string | undefined {
  return contentCache.get(id);
}
@@ -64,7 +168,7 @@ export function useArtifactContent(
  }, [artifact.id, isLoading]);

  useEffect(() => {
-    if (classification.type === "image") {
+    if (classification.type === "image" || classification.type === "video") {
      setContent(null);
      setPdfUrl(null);
      setError(null);
@@ -80,11 +184,8 @@ export function useArtifactContent(
    let objectUrl: string | null = null;
    setContent(null);
    setPdfUrl(null);
-    fetch(artifact.sourceUrl)
-      .then((res) => {
-        if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
-        return res.blob();
-      })
+    fetchArtifactResponse(artifact.sourceUrl)
+      .then((res) => res.blob())
      .then((blob) => {
        objectUrl = URL.createObjectURL(blob);
        if (cancelled) {
@@ -121,11 +222,8 @@ export function useArtifactContent(
      cancelled = true;
    };
  }
-    fetch(artifact.sourceUrl)
-      .then((res) => {
-        if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
-        return res.text();
-      })
+    fetchArtifactResponse(artifact.sourceUrl)
+      .then((res) => res.text())
      .then((text) => {
        if (!cancelled) {
          if (cache.size >= CONTENT_CACHE_MAX) {

@@ -1,5 +1,31 @@
import type { ArtifactRef } from "../../store";

const MAX_RETRIES = 2;
const RETRY_DELAY_MS = 500;

function isTransientError(status: number): boolean {
  return status >= 500 || status === 408 || status === 429;
}

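// Marks HTTP failures that the retry loop has already decided to surface,
// so the catch block rethrows them instead of retrying.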
class DownloadError extends Error {}

async function fetchWithRetry(url: string, retries: number): Promise<Response> {
  for (let attempt = 0; attempt <= retries; attempt++) {
    try {
      const res = await fetch(url);
      if (res.ok) return res;
      if (!isTransientError(res.status) || attempt === retries) {
        throw new DownloadError(`Download failed: ${res.status}`);
      }
    } catch (error) {
      if (error instanceof DownloadError) throw error;
      if (attempt === retries) throw error;
    }
    await new Promise((r) => setTimeout(r, RETRY_DELAY_MS));
  }
  throw new Error("Unreachable");
}

/**
 * Trigger a file download from an artifact URL.
 *
@@ -7,26 +33,28 @@ import type { ArtifactRef } from "../../store";
 * ignores the `download` attribute on cross-origin responses (GCS signed
 * URLs), and some browsers require the anchor to be attached to the DOM
 * before `.click()` fires the download.
 *
 * Retries up to {@link MAX_RETRIES} times on transient server errors (5xx,
 * 408, 429) to handle intermittent proxy/GCS failures.
 */
export function downloadArtifact(artifact: ArtifactRef): Promise<void> {
  // Replace path separators, Windows-reserved chars, control chars, and
  // parent-dir sequences so the browser-assigned filename is safe to write
  // anywhere on the user's filesystem.
-  const safeName =
-    artifact.title
-      .replace(/\.\./g, "_")
-      .replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
-      .replace(/^\.+/, "") || "download";
-  return fetch(artifact.sourceUrl)
-    .then((res) => {
-      if (!res.ok) throw new Error(`Download failed: ${res.status}`);
-      return res.blob();
-    })
+  const collapsedDots = artifact.title.replace(/\.\./g, "");
+  const hasVisibleName = collapsedDots.replace(/^\.+/, "").length > 0;
+  const safeName = artifact.title
+    .replace(/\.\./g, "_")
+    .replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
+    .replace(/^\.+/, "");

+  return fetchWithRetry(artifact.sourceUrl, MAX_RETRIES)
+    .then((res) => res.blob())
    .then((blob) => {
      const url = URL.createObjectURL(blob);
      const a = document.createElement("a");
      a.href = url;
-      a.download = safeName;
+      a.download = safeName && hasVisibleName ? safeName : "download";
      document.body.appendChild(a);
      a.click();
      a.remove();

@@ -56,7 +56,7 @@ describe("classifyArtifact", () => {
    expect(classifyArtifact("application/octet-stream", "x").openable).toBe(
      false,
    );
-    expect(classifyArtifact("video/mp4", "clip.mp4").openable).toBe(false);
    expect(classifyArtifact("audio/mpeg", "track.mp3").openable).toBe(false);
  });

  it("defaults unknown extension+MIME to download-only (not text)", () => {
@@ -76,4 +76,398 @@ describe("classifyArtifact", () => {
    const c = classifyArtifact("text/plain", "data.csv");
    expect(c.type).toBe("csv");
  });

  it("classifies video/mp4 as video (previewable)", () => {
    const c = classifyArtifact("video/mp4", "clip.mp4");
    expect(c.type).toBe("video");
    expect(c.openable).toBe(true);
  });

  it("classifies video/webm as video (previewable)", () => {
    const c = classifyArtifact("video/webm", "clip.webm");
    expect(c.type).toBe("video");
    expect(c.openable).toBe(true);
  });

  // ── Extension coverage ────────────────────────────────────────────

  it("routes .htm as html (not just .html)", () => {
    const c = classifyArtifact(null, "page.htm");
    expect(c.type).toBe("html");
    expect(c.hasSourceToggle).toBe(true);
  });

  it("routes .json as json with source toggle", () => {
    const c = classifyArtifact(null, "config.json");
    expect(c.type).toBe("json");
    expect(c.hasSourceToggle).toBe(true);
  });

  it("routes .txt as text", () => {
    expect(classifyArtifact(null, "notes.txt").type).toBe("text");
  });

  it("routes .log as text", () => {
    expect(classifyArtifact(null, "server.log").type).toBe("text");
  });

  it("routes .mdx as markdown", () => {
    expect(classifyArtifact(null, "docs.mdx").type).toBe("markdown");
  });

  it("routes browser-safe video extensions to video", () => {
    for (const ext of [".mp4", ".webm", ".m4v"]) {
      const c = classifyArtifact(null, `clip${ext}`);
      expect(c.type).toBe("video");
      expect(c.openable).toBe(true);
    }
  });

  it("keeps legacy or unsupported video extensions download-only", () => {
    for (const ext of [".ogg", ".mov", ".avi", ".mkv", ".flv", ".mpeg"]) {
      const c = classifyArtifact(null, `clip${ext}`);
      expect(c.type).toBe("download-only");
      expect(c.openable).toBe(false);
    }
  });

  it("routes all code extensions to code", () => {
    const codeExts = [
      "main.js",
      "app.ts",
      "theme.scss",
      "legacy.less",
      "schema.graphql",
      "query.gql",
      "api.proto",
      "main.dart",
      "lib.rb",
      "server.rs",
      "App.java",
      "main.c",
      "util.cpp",
      "header.h",
      "Program.cs",
      "index.php",
      "main.swift",
      "App.kt",
      "run.sh",
      "start.bash",
      "prompt.zsh",
      "config.toml",
      "settings.ini",
      "app.cfg",
      "query.sql",
      "analysis.r",
      "game.lua",
      "script.pl",
      "Calc.scala",
    ];
    for (const file of codeExts) {
      expect(classifyArtifact(null, file).type).toBe("code");
    }
  });

  it("routes config filenames and extensions to code", () => {
    const configFiles = [
      ".env",
      ".env.local",
      "app.properties",
      "service.conf",
      ".gitignore",
      "Dockerfile",
      "Makefile",
    ];

    for (const file of configFiles) {
      expect(classifyArtifact(null, file).type).toBe("code");
    }
  });

  it("routes .jsonl as code for now", () => {
    const c = classifyArtifact(null, "events.jsonl");
    expect(c.type).toBe("code");
  });

  it("routes .tsv as csv/spreadsheet", () => {
    const c = classifyArtifact(null, "table.tsv");
    expect(c.type).toBe("csv");
    expect(c.hasSourceToggle).toBe(true);
  });

  it("routes .ics and .vcf as text", () => {
    expect(classifyArtifact(null, "calendar.ics").type).toBe("text");
    expect(classifyArtifact(null, "contact.vcf").type).toBe("text");
  });

  it("routes all image extensions to image", () => {
    for (const ext of [".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico"]) {
      expect(classifyArtifact(null, `file${ext}`).type).toBe("image");
    }
  });

  // ── MIME fallback coverage ────────────────────────────────────────

  it("routes application/json MIME to json", () => {
    const c = classifyArtifact("application/json", "noext");
    expect(c.type).toBe("json");
  });

  it("routes text/x-* MIME prefix to code", () => {
    expect(classifyArtifact("text/x-python", "noext").type).toBe("code");
    expect(classifyArtifact("text/x-c", "noext").type).toBe("code");
    expect(classifyArtifact("text/x-java-source", "noext").type).toBe("code");
  });

  it("routes react MIME types to react", () => {
    expect(classifyArtifact("text/jsx", "noext").type).toBe("react");
    expect(classifyArtifact("text/tsx", "noext").type).toBe("react");
    expect(classifyArtifact("application/jsx", "noext").type).toBe("react");
    expect(classifyArtifact("application/x-typescript-jsx", "noext").type).toBe(
      "react",
    );
  });

  it("routes JavaScript/TypeScript MIME to code", () => {
    expect(classifyArtifact("application/javascript", "noext").type).toBe(
      "code",
    );
    expect(classifyArtifact("text/javascript", "noext").type).toBe("code");
    expect(classifyArtifact("application/typescript", "noext").type).toBe(
      "code",
    );
    expect(classifyArtifact("text/typescript", "noext").type).toBe("code");
  });

  it("routes XML MIME to code", () => {
    expect(classifyArtifact("application/xml", "noext").type).toBe("code");
    expect(classifyArtifact("text/xml", "noext").type).toBe("code");
  });

  it("routes text/x-markdown MIME to markdown", () => {
    expect(classifyArtifact("text/x-markdown", "noext").type).toBe("markdown");
  });

  it("routes text/csv MIME to csv", () => {
    expect(classifyArtifact("text/csv", "noext").type).toBe("csv");
  });

  it("routes TSV MIME to csv", () => {
    expect(classifyArtifact("text/tab-separated-values", "noext").type).toBe(
      "csv",
    );
  });

  it("routes unknown text/* MIME to text (not download-only)", () => {
    expect(classifyArtifact("text/rtf", "noext").type).toBe("text");
  });

  it("routes browser-safe image MIME types to image", () => {
    expect(classifyArtifact("image/avif", "noext").type).toBe("image");
  });

  it("keeps unsupported image MIME types download-only", () => {
    for (const mime of [
      "image/tiff",
      "image/x-portable-pixmap",
      "image/x-portable-graymap",
    ]) {
      const c = classifyArtifact(mime, "noext");
      expect(c.type).toBe("download-only");
      expect(c.openable).toBe(false);
    }
  });

  it("routes browser-safe video MIME types to video", () => {
    expect(classifyArtifact("video/mp4", "noext").type).toBe("video");
    expect(classifyArtifact("video/webm", "noext").type).toBe("video");
  });

  it("keeps legacy or unsupported video MIME types download-only", () => {
    for (const mime of [
      "video/x-msvideo",
      "video/x-flv",
      "video/mpeg",
      "video/quicktime",
      "video/x-matroska",
      "video/ogg",
    ]) {
      const c = classifyArtifact(mime, "noext");
      expect(c.type).toBe("download-only");
      expect(c.openable).toBe(false);
    }
  });

  // ── BINARY_MIMES coverage ────────────────────────────────────────

  it("treats all BINARY_MIMES entries as download-only", () => {
    const binaryMimes = [
      "application/zip",
      "application/x-zip-compressed",
      "application/gzip",
      "application/x-tar",
      "application/x-rar-compressed",
      "application/x-7z-compressed",
      "application/octet-stream",
      "application/x-executable",
      "application/x-msdos-program",
      "application/vnd.microsoft.portable-executable",
    ];
    for (const mime of binaryMimes) {
      const c = classifyArtifact(mime, "noext");
      expect(c.openable).toBe(false);
      expect(c.type).toBe("download-only");
    }
  });

  it("treats audio/* MIME as download-only", () => {
    expect(classifyArtifact("audio/mpeg", "noext").openable).toBe(false);
    expect(classifyArtifact("audio/wav", "noext").openable).toBe(false);
    expect(classifyArtifact("audio/ogg", "noext").openable).toBe(false);
  });

  // ── Size gate edge cases ──────────────────────────────────────────

it("does NOT gate files at exactly 10MB (boundary is >10MB)", () => {
|
||||
const tenMB = 10 * 1024 * 1024;
|
||||
const c = classifyArtifact("text/plain", "exact.txt", tenMB);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
it("gates files at 10MB + 1 byte", () => {
|
||||
const overTenMB = 10 * 1024 * 1024 + 1;
|
||||
const c = classifyArtifact("text/plain", "big.txt", overTenMB);
|
||||
expect(c.type).toBe("download-only");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not gate when sizeBytes is 0", () => {
|
||||
const c = classifyArtifact("text/plain", "empty.txt", 0);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
it("does not gate when sizeBytes is undefined", () => {
|
||||
const c = classifyArtifact("text/plain", "file.txt", undefined);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
// ── Extension over MIME priority ──────────────────────────────────
|
||||
|
||||
it("extension wins over MIME for JSON (MIME says text, ext says json)", () => {
|
||||
const c = classifyArtifact("text/plain", "data.json");
|
||||
expect(c.type).toBe("json");
|
||||
});
|
||||
|
||||
it("extension wins over MIME for markdown", () => {
|
||||
const c = classifyArtifact("text/plain", "README.md");
|
||||
expect(c.type).toBe("markdown");
|
||||
});
|
||||
|
||||
// ── Null/missing inputs ───────────────────────────────────────────
|
||||
|
||||
it("handles null MIME with no filename as download-only", () => {
|
||||
const c = classifyArtifact(null, undefined);
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("handles null MIME with empty filename as download-only", () => {
|
||||
const c = classifyArtifact(null, "");
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("handles known config files with no extension", () => {
|
||||
const c = classifyArtifact(null, "Makefile");
|
||||
expect(c.type).toBe("code");
|
||||
});
|
||||
|
||||
// ── Exotic/compound extensions must NOT open the side panel ───────
|
||||
// These are real file types agents might produce. Every single one
|
||||
// must be download-only so we never try to render binary garbage.
|
||||
|
||||
it("does not open .tar.gz (compound extension takes last segment)", () => {
|
||||
// getExtension("archive.tar.gz") → ".gz" which is not in EXT_KIND
|
||||
const c = classifyArtifact(null, "archive.tar.gz");
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("does not open .tar.bz2", () => {
|
||||
const c = classifyArtifact(null, "archive.tar.bz2");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open .tar.xz", () => {
|
||||
const c = classifyArtifact(null, "archive.tar.xz");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open common binary formats", () => {
|
||||
const binaries = [
|
||||
"setup.exe",
|
||||
"library.dll",
|
||||
"image.iso",
|
||||
"installer.dmg",
|
||||
"package.deb",
|
||||
"package.rpm",
|
||||
"module.wasm",
|
||||
"Main.class",
|
||||
"module.pyc",
|
||||
"app.apk",
|
||||
"game.pak",
|
||||
"model.onnx",
|
||||
"weights.pt",
|
||||
"data.parquet",
|
||||
"archive.rar",
|
||||
"archive.7z",
|
||||
"disk.vhd",
|
||||
"disk.vmdk",
|
||||
"firmware.bin",
|
||||
"core.dump",
|
||||
"database.sqlite",
|
||||
"database.db",
|
||||
"index.idx",
|
||||
];
|
||||
for (const file of binaries) {
|
||||
const c = classifyArtifact(null, file);
|
||||
expect(c.openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open binary MIME types even with a misleading extension", () => {
|
||||
// Extension is unknown, MIME is binary
|
||||
const c = classifyArtifact("application/x-executable", "run.elf");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open files with random/made-up extensions", () => {
|
||||
const weirdExts = [
|
||||
"output.xyz",
|
||||
"data.foo",
|
||||
"file.asdf",
|
||||
"thing.blargh",
|
||||
"result.out",
|
||||
"x.1234",
|
||||
];
|
||||
for (const file of weirdExts) {
|
||||
const c = classifyArtifact(null, file);
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open font files", () => {
|
||||
for (const file of ["sans.ttf", "serif.otf", "icon.woff", "icon.woff2"]) {
|
||||
expect(classifyArtifact(null, file).openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open certificate/key files", () => {
|
||||
// .pem and .key have no extension mapping and null MIME → download-only
|
||||
for (const file of ["cert.pem", "server.key", "ca.crt", "id.p12"]) {
|
||||
expect(classifyArtifact(null, file).openable).toBe(false);
|
||||
}
|
||||
});
|
||||
});
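For orientation while reading the suite above: a minimal sketch of the contract these tests exercise. The signature and return shape are inferred from the assertions and from the type union diffed below, not copied verbatim from the source; fields the tests never assert (such as `icon` and `label`) are omitted.

```ts
// Inferred contract for the classifier under test (see the helpers diff below).
type ArtifactType =
  | "code" | "react" | "html" | "markdown" | "csv" | "json"
  | "image" | "video" | "pdf" | "text" | "download-only";

interface ArtifactClassification {
  type: ArtifactType;
  openable: boolean; // false means the side panel never opens for this file
  hasSourceToggle: boolean; // e.g. csv previews offer a raw-source view
}

declare function classifyArtifact(
  mimeType: string | null,
  filename?: string,
  sizeBytes?: number,
): ArtifactClassification;
```

The suite treats the function as total: every input, however exotic, must resolve to some classification, with "download-only" as the safe fallback.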

@@ -5,6 +5,7 @@ import {
  FileText,
  Image,
  Table,
  VideoCamera,
} from "@phosphor-icons/react";
import type { Icon } from "@phosphor-icons/react";

@@ -17,6 +18,7 @@ export interface ArtifactClassification {
    | "csv"
    | "json"
    | "image"
    | "video"
    | "pdf"
    | "text"
    | "download-only";

@@ -38,6 +40,13 @@ const KIND: Record<string, ArtifactClassification> = {
    openable: true,
    hasSourceToggle: false,
  },
  video: {
    type: "video",
    icon: VideoCamera,
    label: "Video",
    openable: true,
    hasSourceToggle: false,
  },
  pdf: {
    type: "pdf",
    icon: FileText,

@@ -113,8 +122,13 @@ const EXT_KIND: Record<string, string> = {
  ".svg": "image",
  ".bmp": "image",
  ".ico": "image",
  ".avif": "image",
  ".mp4": "video",
  ".webm": "video",
  ".m4v": "video",
  ".pdf": "pdf",
  ".csv": "csv",
  ".tsv": "csv",
  ".html": "html",
  ".htm": "html",
  ".jsx": "react",

@@ -122,11 +136,17 @@ const EXT_KIND: Record<string, string> = {
  ".md": "markdown",
  ".mdx": "markdown",
  ".json": "json",
  ".jsonl": "code",
  ".txt": "text",
  ".log": "text",
  ".ics": "text",
  ".vcf": "text",
  ".env": "code",
  ".gitignore": "code",
  // code extensions
  ".js": "code",
  ".ts": "code",
  ".dart": "code",
  ".py": "code",
  ".rb": "code",
  ".go": "code",

@@ -142,11 +162,19 @@ const EXT_KIND: Record<string, string> = {
  ".sh": "code",
  ".bash": "code",
  ".zsh": "code",
  ".scss": "code",
  ".sass": "code",
  ".less": "code",
  ".graphql": "code",
  ".gql": "code",
  ".proto": "code",
  ".yml": "code",
  ".yaml": "code",
  ".toml": "code",
  ".ini": "code",
  ".cfg": "code",
  ".conf": "code",
  ".properties": "code",
  ".sql": "code",
  ".r": "code",
  ".lua": "code",

@@ -154,10 +182,16 @@ const EXT_KIND: Record<string, string> = {
  ".scala": "code",
};

const EXACT_FILENAME_KIND: Record<string, string> = {
  dockerfile: "code",
  makefile: "code",
};

// Exact-match MIME → kind (fallback when extension doesn't match).
const MIME_KIND: Record<string, string> = {
  "application/pdf": "pdf",
  "text/csv": "csv",
  "text/tab-separated-values": "csv",
  "text/html": "html",
  "text/jsx": "react",
  "text/tsx": "react",

@@ -166,6 +200,9 @@ const MIME_KIND: Record<string, string> = {
  "text/markdown": "markdown",
  "text/x-markdown": "markdown",
  "application/json": "json",
  "application/x-ndjson": "code",
  "application/ndjson": "code",
  "application/jsonl": "code",
  "application/javascript": "code",
  "text/javascript": "code",
  "application/typescript": "code",

@@ -182,11 +219,37 @@ const BINARY_MIMES = new Set([
  "application/x-rar-compressed",
  "application/x-7z-compressed",
  "application/octet-stream",
  "application/wasm",
  "application/x-executable",
  "application/x-msdos-program",
  "application/vnd.microsoft.portable-executable",
]);

const PREVIEWABLE_IMAGE_MIMES = new Set([
  "image/png",
  "image/jpeg",
  "image/gif",
  "image/webp",
  "image/svg+xml",
  "image/bmp",
  "image/x-icon",
  "image/vnd.microsoft.icon",
  "image/avif",
]);

const PREVIEWABLE_VIDEO_MIMES = new Set([
  "video/mp4",
  "video/webm",
  "video/x-m4v",
]);

function getBasename(filename?: string): string {
  if (!filename) return "";
  const normalized = filename.replace(/\\/g, "/");
  const parts = normalized.split("/");
  return parts[parts.length - 1]?.toLowerCase() ?? "";
}

function getExtension(filename?: string): string {
  if (!filename) return "";
  const lastDot = filename.lastIndexOf(".");

@@ -202,24 +265,36 @@ export function classifyArtifact(
  // Size gate: >10MB is download-only regardless of type.
  if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];

  const basename = getBasename(filename);
  const exactKind = EXACT_FILENAME_KIND[basename];
  if (exactKind) return KIND[exactKind];

  if (basename === ".env" || basename.startsWith(".env.")) {
    return KIND.code;
  }

  // Extension first (more reliable than MIME for AI-generated files).
  const ext = getExtension(filename);
  const ext = getExtension(basename);
  const extKind = EXT_KIND[ext];
  if (extKind) return KIND[extKind];

  // MIME fallbacks.
  const mime = (mimeType ?? "").toLowerCase();
  if (mime.startsWith("image/")) return KIND.image;
  if (PREVIEWABLE_IMAGE_MIMES.has(mime)) return KIND.image;
  if (PREVIEWABLE_VIDEO_MIMES.has(mime)) return KIND.video;
  const mimeKind = MIME_KIND[mime];
  if (mimeKind) return KIND[mimeKind];
  if (mime.startsWith("text/x-")) return KIND.code;
  if (
    BINARY_MIMES.has(mime) ||
    mime.startsWith("audio/") ||
    mime.startsWith("video/")
    mime.startsWith("image/") ||
    mime.startsWith("video/") ||
    mime.startsWith("font/")
  ) {
    return KIND["download-only"];
  }
  if (BINARY_MIMES.has(mime) || mime.startsWith("audio/")) {
    return KIND["download-only"];
  }
  if (mime.startsWith("text/")) return KIND.text;

  // Unknown extension + unknown MIME: don't open — we can't safely assume
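The interleaved before/after lines in the hunk above make the new decision ladder hard to scan; a few spot checks, mirroring the test expectations earlier in this diff (comments show the classification each call should produce):

```ts
// Illustrative spot checks of the reordered fallbacks (values per the tests above).
classifyArtifact("text/plain", "data.json"); // → json (extension beats MIME)
classifyArtifact("image/avif", "noext"); // → image (allow-listed MIME)
classifyArtifact("video/quicktime", "noext"); // → download-only (not allow-listed)
classifyArtifact(null, "archive.tar.gz"); // → download-only (".gz" not in EXT_KIND)
classifyArtifact("text/plain", "big.txt", 10 * 1024 * 1024 + 1); // → download-only (size gate)
```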

@@ -83,6 +83,7 @@ export function useArtifactPanel() {
  const canCopy =
    classification != null &&
    classification.type !== "image" &&
    classification.type !== "video" &&
    classification.type !== "download-only" &&
    classification.type !== "pdf";

@@ -64,10 +64,7 @@ export const ChatContainer = ({
  // open state drive layout width; an artifact generated in a stale session
  // state would otherwise shrink the chat column with no panel rendered.
  const isArtifactOpen = isArtifactsEnabled && isArtifactPanelOpen;
  useAutoOpenArtifacts({
    messages: isArtifactsEnabled ? messages : [],
    sessionId,
  });
  useAutoOpenArtifacts({ sessionId });
  const isBusy =
    status === "streaming" ||
    status === "submitted" ||

@@ -0,0 +1,77 @@
import { describe, expect, it, beforeEach, afterEach } from "vitest";
import { renderHook } from "@testing-library/react";
import { useAutoOpenArtifacts } from "../useAutoOpenArtifacts";
import { useCopilotUIStore } from "../../../store";

// Capture the real store actions before any test can replace them.
const realOpenArtifact = useCopilotUIStore.getState().openArtifact;
const realResetArtifactPanel = useCopilotUIStore.getState().resetArtifactPanel;

function resetStore() {
  useCopilotUIStore.setState({
    openArtifact: realOpenArtifact,
    resetArtifactPanel: realResetArtifactPanel,
    artifactPanel: {
      isOpen: false,
      isMinimized: false,
      isMaximized: false,
      width: 600,
      activeArtifact: null,
      history: [],
    },
  });
}

describe("useAutoOpenArtifacts", () => {
  beforeEach(resetStore);
  afterEach(resetStore);

  it("does not auto-open artifacts on initial message load", () => {
    renderHook(() => useAutoOpenArtifacts({ sessionId: "session-1" }));
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
  });

  it("does not auto-open when rerendering within the same session", () => {
    const { rerender } = renderHook(
      ({ sessionId }: { sessionId: string }) =>
        useAutoOpenArtifacts({ sessionId }),
      { initialProps: { sessionId: "session-1" } },
    );

    rerender({ sessionId: "session-1" });
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
  });

  it("panel should fully reset when session changes", () => {
    const artifact = {
      id: "file1",
      title: "image.png",
      mimeType: "image/png",
      sourceUrl: "/api/proxy/api/workspace/files/file1/download",
      origin: "agent" as const,
    };
    useCopilotUIStore.getState().openArtifact(artifact);
    useCopilotUIStore.getState().openArtifact({
      ...artifact,
      id: "file2",
      title: "second.png",
      sourceUrl: "/api/proxy/api/workspace/files/file2/download",
    });
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);

    const { rerender } = renderHook(
      ({ sessionId }: { sessionId: string }) =>
        useAutoOpenArtifacts({ sessionId }),
      { initialProps: { sessionId: "session-1" } },
    );

    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);

    rerender({ sessionId: "session-2" });

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(false);
    expect(s.activeArtifact).toBeNull();
    expect(s.history).toEqual([]);
  });
});
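One detail worth flagging in the new test file above: the real `openArtifact` and `resetArtifactPanel` actions are captured at module load and re-installed in `resetStore()`. Since the Zustand store is a module-level singleton shared across tests, any test that swaps an action for a mock would otherwise leak that mock into every later test. A reduced sketch of the hazard (the mock here is hypothetical, for illustration only):

```ts
// A test that replaces an action on the shared store...
useCopilotUIStore.setState({ openArtifact: vi.fn() });
// ...would poison the whole file unless resetStore() restores the
// originals captured before any test ran:
useCopilotUIStore.setState({ openArtifact: realOpenArtifact });
```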

@@ -3,17 +3,19 @@ import { beforeEach, describe, expect, it } from "vitest";
import { useCopilotUIStore } from "../../store";
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";

function assistantMessageWithText(id: string, text: string) {
  return {
    id,
    role: "assistant" as const,
    parts: [{ type: "text" as const, text }],
  };
}

const A_ID = "11111111-0000-0000-0000-000000000000";
const B_ID = "22222222-0000-0000-0000-000000000000";

function makeArtifact(id: string, title = `${id}.txt`) {
  return {
    id,
    title,
    mimeType: "text/plain",
    sourceUrl: `/api/proxy/api/workspace/files/${id}/download`,
    origin: "agent" as const,
  };
}

function resetStore() {
  useCopilotUIStore.setState({
    artifactPanel: {

@@ -30,111 +32,60 @@ function resetStore() {
describe("useAutoOpenArtifacts", () => {
  beforeEach(resetStore);

  it("does NOT auto-open on the initial hydration of message list (baseline pass)", () => {
    const messages = [
      assistantMessageWithText("m1", `[a](workspace://${A_ID})`),
    ];
    renderHook(() =>
      useAutoOpenArtifacts({ messages: messages as any, sessionId: "s1" }),
    );
    // Initial run just records the baseline fingerprint; nothing opens.
  it("does not auto-open on initial render", () => {
    renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
  });

  it("auto-opens when an existing assistant message adds a new artifact", () => {
    // 1st render: baseline with no artifact.
    const initial = [assistantMessageWithText("m1", "thinking...")];
  it("does not auto-open when rerendering within the same session", () => {
    const { rerender } = renderHook(
      ({ messages, sessionId }) =>
        useAutoOpenArtifacts({ messages: messages as any, sessionId }),
      { initialProps: { messages: initial, sessionId: "s1" } },
      ({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
      { initialProps: { sessionId: "s1" } },
    );
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);

    // 2nd render: same message id now contains an artifact link.
    act(() => {
      rerender({
        messages: [
          assistantMessageWithText("m1", `here: [A](workspace://${A_ID})`),
        ],
        sessionId: "s1",
      });
      rerender({ sessionId: "s1" });
    });

    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
  });

  it("resets the panel state when sessionId changes", () => {
    useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
    useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));

    const { rerender } = renderHook(
      ({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
      { initialProps: { sessionId: "s1" } },
    );

    act(() => {
      rerender({ sessionId: "s2" });
    });

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(true);
    expect(s.activeArtifact?.id).toBe(A_ID);
    expect(s.isOpen).toBe(false);
    expect(s.activeArtifact).toBeNull();
    expect(s.history).toEqual([]);
  });

  it("does not re-open when the fingerprint hasn't changed", () => {
    const msg = assistantMessageWithText("m1", `[A](workspace://${A_ID})`);
  it("does not carry a stale back stack into the next session", () => {
    useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
    useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));

    const { rerender } = renderHook(
      ({ messages, sessionId }) =>
        useAutoOpenArtifacts({ messages: messages as any, sessionId }),
      { initialProps: { messages: [msg], sessionId: "s1" } },
      ({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
      { initialProps: { sessionId: "s1" } },
    );
    // Baseline captured; no open.
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);

    // Rerender identical content: no change in fingerprint → no open.
    act(() => {
      rerender({ messages: [msg], sessionId: "s1" });
      rerender({ sessionId: "s2" });
    });
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
  });

  it("auto-opens when a brand-new assistant message arrives after the baseline is established", () => {
    // First render: one message without artifacts → establishes baseline.
    const { rerender } = renderHook(
      ({ messages, sessionId }) =>
        useAutoOpenArtifacts({ messages: messages as any, sessionId }),
      {
        initialProps: {
          messages: [assistantMessageWithText("m1", "plain")] as any,
          sessionId: "s1",
        },
      },
    );
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
    useCopilotUIStore.getState().openArtifact(makeArtifact("c", "c.txt"));

    // Second render: a *new* assistant message with an artifact. Baseline
    // is already set, so this should auto-open.
    act(() => {
      rerender({
        messages: [
          assistantMessageWithText("m1", "plain"),
          assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
        ] as any,
        sessionId: "s1",
      });
    });
    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(true);
    expect(s.activeArtifact?.id).toBe(B_ID);
  });

  it("resets hydration baseline when sessionId changes", () => {
    const { rerender } = renderHook(
      ({ messages, sessionId }) =>
        useAutoOpenArtifacts({ messages: messages as any, sessionId }),
      {
        initialProps: {
          messages: [
            assistantMessageWithText("m1", `[A](workspace://${A_ID})`),
          ] as any,
          sessionId: "s1",
        },
      },
    );
    // Switch to a new session — the first pass on the new session should
    // NOT auto-open (it's a fresh hydration).
    act(() => {
      rerender({
        messages: [
          assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
        ] as any,
        sessionId: "s2",
      });
    });
    expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
    expect(s.activeArtifact?.id).toBe("c");
    expect(s.history).toEqual([]);
  });
});

@@ -1,91 +1,29 @@
"use client";

import { UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef } from "react";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
import { getMessageArtifacts } from "../ChatMessagesContainer/helpers";

function fingerprintArtifacts(artifacts: ArtifactRef[]): string {
  return artifacts
    .map((a) => `${a.id}:${a.title}:${a.mimeType ?? ""}:${a.sourceUrl}`)
    .join("|");
}

interface UseAutoOpenArtifactsOptions {
  messages: UIMessage<unknown, UIDataTypes, UITools>[];
  sessionId: string | null;
}

export function useAutoOpenArtifacts({
  messages,
  sessionId,
}: UseAutoOpenArtifactsOptions) {
  const openArtifact = useCopilotUIStore((state) => state.openArtifact);
  const messageFingerprintsRef = useRef<Map<string, string>>(new Map());
  const hasInitializedRef = useRef(false);
  const resetArtifactPanel = useCopilotUIStore(
    (state) => state.resetArtifactPanel,
  );
  const prevSessionIdRef = useRef(sessionId);

  useEffect(() => {
    messageFingerprintsRef.current = new Map();
    hasInitializedRef.current = false;
  }, [sessionId]);
  const isSessionChange = prevSessionIdRef.current !== sessionId;
  prevSessionIdRef.current = sessionId;

  useEffect(() => {
    if (messages.length === 0) {
      messageFingerprintsRef.current = new Map();
      return;
    // Artifact previews should open only from an explicit user click.
    // When the session changes, fully clear the panel state so stale
    // active artifacts and back-stack entries never bleed into the next chat.
    if (isSessionChange) {
      resetArtifactPanel();
    }

    // Only scan messages whose fingerprint might have changed since the
    // last pass: that's the last assistant message (currently streaming)
    // plus any assistant message whose id isn't in the baseline yet.
    // This keeps the cost O(new+tail), not O(all messages), on every chunk.
    const previous = messageFingerprintsRef.current;
    const nextFingerprints = new Map<string, string>(previous);
    let nextArtifact: ArtifactRef | null = null;
    const lastAssistantIdx = (() => {
      for (let i = messages.length - 1; i >= 0; i--) {
        if (messages[i].role === "assistant") return i;
      }
      return -1;
    })();

    for (let i = 0; i < messages.length; i++) {
      const message = messages[i];
      if (message.role !== "assistant") continue;
      const isTailAssistant = i === lastAssistantIdx;
      const isNewMessage = !previous.has(message.id);
      if (!isTailAssistant && !isNewMessage) continue;

      const artifacts = getMessageArtifacts(message);
      const fingerprint = fingerprintArtifacts(artifacts);
      nextFingerprints.set(message.id, fingerprint);

      if (!hasInitializedRef.current || fingerprint.length === 0) {
        continue;
      }

      const previousFingerprint = previous.get(message.id) ?? "";
      if (previousFingerprint === fingerprint) continue;

      nextArtifact = artifacts[artifacts.length - 1] ?? nextArtifact;
    }

    // Drop entries for messages that no longer exist (e.g. history truncated).
    const liveIds = new Set(messages.map((m) => m.id));
    for (const id of nextFingerprints.keys()) {
      if (!liveIds.has(id)) nextFingerprints.delete(id);
    }

    messageFingerprintsRef.current = nextFingerprints;

    if (!hasInitializedRef.current) {
      hasInitializedRef.current = true;
      return;
    }

    if (nextArtifact) {
      openArtifact(nextArtifact);
    }
  }, [messages, openArtifact]);
  }, [sessionId, resetArtifactPanel]);
}
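The net effect of this rewrite: the hook no longer inspects messages or fingerprints at all; it is purely a session-boundary reset, and previews open only from an explicit user click. A usage sketch under that assumption (the consumer component here is hypothetical; it mirrors how ChatContainer calls the hook above):

```tsx
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";

// Hypothetical consumer: all the hook needs now is the session identity.
function ChatShell({ sessionId }: { sessionId: string | null }) {
  // On every sessionId change the panel is fully cleared (isOpen,
  // activeArtifact, history), so nothing from the previous chat bleeds in.
  useAutoOpenArtifacts({ sessionId });
  return null;
}
```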

@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { DryRunToggleButton } from "./components/DryRunToggleButton";
import { FileChips } from "./components/FileChips";
import { ModelToggleButton } from "./components/ModelToggleButton";
import { ModeToggleButton } from "./components/ModeToggleButton";
import { RecordingButton } from "./components/RecordingButton";
import { RecordingIndicator } from "./components/RecordingIndicator";

@@ -50,16 +51,22 @@ export function ChatInput({
  onDroppedFilesConsumed,
  hasSession = false,
}: Props) {
  const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
    useCopilotUIStore();
  const {
    copilotChatMode,
    setCopilotChatMode,
    copilotLlmModel,
    setCopilotLlmModel,
    isDryRun,
    setIsDryRun,
  } = useCopilotUIStore();
  const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
  const showDryRunToggle = showModeToggle;
  const [files, setFiles] = useState<File[]>([]);

  function handleToggleMode() {
    const next =
      copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
    setCopilotMode(next);
      copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking";
    setCopilotChatMode(next);
    toast({
      title:
        next === "fast"

@@ -72,6 +79,21 @@ export function ChatInput({
    });
  }

  function handleToggleModel() {
    const next = copilotLlmModel === "advanced" ? "standard" : "advanced";
    setCopilotLlmModel(next);
    toast({
      title:
        next === "advanced"
          ? "Switched to Advanced model"
          : "Switched to Standard model",
      description:
        next === "advanced"
          ? "Using the highest-capability model."
          : "Using the balanced standard model.",
    });
  }

  function handleToggleDryRun() {
    const next = !isDryRun;
    setIsDryRun(next);

@@ -196,17 +218,28 @@ export function ChatInput({
            onFilesSelected={handleFilesSelected}
            disabled={isBusy}
          />
          {/* Mode and model are per-message settings sent with each stream request,
              so they can be freely changed between turns in an existing session.
              Hide only while actively streaming (too late to change for that turn). */}
          {showModeToggle && !isStreaming && (
            <ModeToggleButton
              mode={copilotMode}
              mode={copilotChatMode}
              onToggle={handleToggleMode}
            />
          )}
          {showDryRunToggle && (!hasSession || isDryRun) && (
          {showModeToggle && !isStreaming && (
            <ModelToggleButton
              model={copilotLlmModel}
              onToggle={handleToggleModel}
            />
          )}
          {/* DryRun button only on new chats: once a session exists its
              dry_run flag is locked and should be read from session metadata
              (sessionDryRun in useCopilotPage), not toggled here. The banner
              in CopilotPage.tsx reflects the actual session state. */}
          {showDryRunToggle && !hasSession && (
            <DryRunToggleButton
              isDryRun={isDryRun}
              isStreaming={isStreaming}
              readOnly={hasSession}
              onToggle={handleToggleDryRun}
            />
          )}

@@ -8,14 +8,23 @@ import { afterEach, describe, expect, it, vi } from "vitest";
import { ChatInput } from "../ChatInput";

let mockCopilotMode = "extended_thinking";
const mockSetCopilotMode = vi.fn((mode: string) => {
const mockSetCopilotChatMode = vi.fn((mode: string) => {
  mockCopilotMode = mode;
});

let mockCopilotLlmModel = "standard";
const mockSetCopilotLlmModel = vi.fn((model: string) => {
  mockCopilotLlmModel = model;
});

vi.mock("@/app/(platform)/copilot/store", () => ({
  useCopilotUIStore: () => ({
    copilotMode: mockCopilotMode,
    setCopilotMode: mockSetCopilotMode,
    copilotChatMode: mockCopilotMode,
    setCopilotChatMode: mockSetCopilotChatMode,
    copilotLlmModel: mockCopilotLlmModel,
    setCopilotLlmModel: mockSetCopilotLlmModel,
    isDryRun: false,
    setIsDryRun: vi.fn(),
    initialPrompt: null,
    setInitialPrompt: vi.fn(),
  }),

@@ -107,6 +116,7 @@ afterEach(() => {
  cleanup();
  vi.clearAllMocks();
  mockCopilotMode = "extended_thinking";
  mockCopilotLlmModel = "standard";
});

describe("ChatInput mode toggle", () => {

@@ -141,7 +151,7 @@ describe("ChatInput mode toggle", () => {
    mockCopilotMode = "extended_thinking";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
    expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
    expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast");
  });

  it("toggles from fast to extended_thinking on click", () => {

@@ -149,7 +159,7 @@ describe("ChatInput mode toggle", () => {
    mockCopilotMode = "fast";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
    expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
    expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
  });

  it("hides toggle button when streaming", () => {

@@ -158,6 +168,15 @@ describe("ChatInput mode toggle", () => {
    expect(screen.queryByLabelText(/switch to/i)).toBeNull();
  });

  it("shows mode toggle when hasSession is true and not streaming", () => {
    // Mode is per-message — can be changed between turns even in an existing session.
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} hasSession />);
    expect(
      screen.queryByLabelText(/switch to (fast|extended thinking) mode/i),
    ).not.toBeNull();
  });

  it("exposes aria-pressed=true in extended_thinking mode", () => {
    mockFlagValue = true;
    mockCopilotMode = "extended_thinking";

@@ -187,3 +206,93 @@ describe("ChatInput mode toggle", () => {
    );
  });
});

describe("ChatInput model toggle", () => {
  it("renders model toggle button when flag is enabled", () => {
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} />);
    expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined();
  });

  it("does not render model toggle when flag is disabled", () => {
    mockFlagValue = false;
    render(<ChatInput onSend={mockOnSend} />);
    expect(
      screen.queryByLabelText(/switch to (advanced|standard) model/i),
    ).toBeNull();
  });

  it("toggles from standard to advanced on click", () => {
    mockFlagValue = true;
    mockCopilotLlmModel = "standard";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
  });

  it("toggles from advanced to standard on click", () => {
    mockFlagValue = true;
    mockCopilotLlmModel = "advanced";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to standard model/i));
    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
  });

  it("hides model toggle when streaming", () => {
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} isStreaming />);
    expect(
      screen.queryByLabelText(/switch to (advanced|standard) model/i),
    ).toBeNull();
  });

  it("shows model toggle when hasSession is true and not streaming", () => {
    // Model is per-message — can be changed between turns even in an existing session.
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} hasSession />);
    expect(
      screen.queryByLabelText(/switch to (advanced|standard) model/i),
    ).not.toBeNull();
  });

  it("hides dry-run toggle when hasSession is true", () => {
    // DryRun button is only for new chats — once a session exists its dry_run
    // flag is immutable and shown via the CopilotPage banner, not this button.
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} hasSession />);
    expect(screen.queryByLabelText(/test mode/i)).toBeNull();
    expect(screen.queryByLabelText(/enable test mode/i)).toBeNull();
  });

  it("shows dry-run toggle when no session", () => {
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} />);
    expect(screen.getByLabelText(/test mode|enable test mode/i)).toBeTruthy();
  });

  it("shows a toast when switching to advanced", async () => {
    const { toast } = await import("@/components/molecules/Toast/use-toast");
    mockFlagValue = true;
    mockCopilotLlmModel = "standard";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
    expect(toast).toHaveBeenCalledWith(
      expect.objectContaining({
        title: expect.stringMatching(/switched to advanced model/i),
      }),
    );
  });

  it("shows a toast when switching to standard", async () => {
    const { toast } = await import("@/components/molecules/Toast/use-toast");
    mockFlagValue = true;
    mockCopilotLlmModel = "advanced";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to standard model/i));
    expect(toast).toHaveBeenCalledWith(
      expect.objectContaining({
        title: expect.stringMatching(/switched to standard model/i),
      }),
    );
  });
});

@@ -3,42 +3,34 @@
import { cn } from "@/lib/utils";
import { Flask } from "@phosphor-icons/react";

// This button is only rendered on NEW chats (no active session).
// Once a session exists, it is hidden — the session's dry_run flag is
// immutable and reflected in the banner in CopilotPage.tsx instead.
// Do NOT add readOnly/hasSession handling here; hide it at the call site.
interface Props {
  isDryRun: boolean;
  isStreaming: boolean;
  readOnly?: boolean;
  onToggle: () => void;
}

export function DryRunToggleButton({
  isDryRun,
  isStreaming,
  readOnly = false,
  onToggle,
}: Props) {
  const isDisabled = isStreaming || readOnly;
export function DryRunToggleButton({ isDryRun, onToggle }: Props) {
  return (
    <button
      type="button"
      aria-pressed={isDryRun}
      disabled={isDisabled}
      onClick={readOnly ? undefined : onToggle}
      onClick={onToggle}
      className={cn(
        "inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
        isDryRun
          ? "bg-amber-100 text-amber-900 hover:bg-amber-200"
          : "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
        isDisabled && "cursor-default opacity-70",
      )}
      aria-label={isDryRun ? "Test mode active" : "Enable Test mode"}
      aria-label={
        isDryRun ? "Test mode active — click to disable" : "Enable Test mode"
      }
      title={
        readOnly
          ? "Test mode active for this session"
          : isStreaming
            ? "Cannot change mode while streaming"
            : isDryRun
              ? "Test mode ON — click to disable"
              : "Enable Test mode — agents will run as dry-run"
        isDryRun
          ? "Test mode ON — new chats run agents as simulation (click to disable)"
          : "Enable Test mode — new chats will run agents as simulation"
      }
    >
      <Flask size={14} />

@@ -0,0 +1,38 @@
"use client";

import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import type { CopilotLlmModel } from "../../../store";

interface Props {
  model: CopilotLlmModel;
  onToggle: () => void;
}

export function ModelToggleButton({ model, onToggle }: Props) {
  const isAdvanced = model === "advanced";
  return (
    <button
      type="button"
      aria-pressed={isAdvanced}
      onClick={onToggle}
      className={cn(
        "inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
        isAdvanced
          ? "bg-sky-100 text-sky-900 hover:bg-sky-200"
          : "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
      )}
      aria-label={
        isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
      }
      title={
        isAdvanced
          ? "Advanced model — highest capability (click to switch to Standard)"
          : "Standard model — click to switch to Advanced"
      }
    >
      <Cpu size={14} />
      {isAdvanced && "Advanced"}
    </button>
  );
}

@@ -0,0 +1,41 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { DryRunToggleButton } from "../DryRunToggleButton";

afterEach(cleanup);

// DryRunToggleButton only appears on new chats (no active session).
// It has no readOnly/isStreaming props — those scenarios are handled by hiding
// the button entirely at the ChatInput level when hasSession is true.
describe("DryRunToggleButton", () => {
  it("shows Test label when isDryRun is true", () => {
    render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
    expect(screen.getByText("Test")).toBeTruthy();
  });

  it("shows no text label when isDryRun is false", () => {
    render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
    expect(screen.queryByText("Test")).toBeNull();
  });

  it("calls onToggle when clicked", () => {
    const onToggle = vi.fn();
    render(<DryRunToggleButton isDryRun={false} onToggle={onToggle} />);
    fireEvent.click(screen.getByRole("button"));
    expect(onToggle).toHaveBeenCalledTimes(1);
  });

  it("sets aria-pressed=true when isDryRun is true", () => {
    render(<DryRunToggleButton isDryRun={true} onToggle={vi.fn()} />);
    expect(screen.getByRole("button").getAttribute("aria-pressed")).toBe(
      "true",
    );
  });

  it("sets aria-pressed=false when isDryRun is false", () => {
    render(<DryRunToggleButton isDryRun={false} onToggle={vi.fn()} />);
    expect(screen.getByRole("button").getAttribute("aria-pressed")).toBe(
      "false",
    );
  });
});

@@ -0,0 +1,37 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { ModelToggleButton } from "../ModelToggleButton";

afterEach(cleanup);

describe("ModelToggleButton", () => {
  it("shows no text label when model is standard", () => {
    render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
    expect(screen.queryByText("Standard")).toBeNull();
    expect(screen.queryByText("Advanced")).toBeNull();
  });

  it("shows Advanced label when model is advanced", () => {
    render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
    expect(screen.getByText("Advanced")).toBeTruthy();
  });

  it("calls onToggle when clicked", () => {
    const onToggle = vi.fn();
    render(<ModelToggleButton model="standard" onToggle={onToggle} />);
    fireEvent.click(screen.getByRole("button"));
    expect(onToggle).toHaveBeenCalledTimes(1);
  });

  it("sets aria-pressed=false for standard", () => {
    render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
    const btn = screen.getByLabelText("Switch to Advanced model");
    expect(btn.getAttribute("aria-pressed")).toBe("false");
  });

  it("sets aria-pressed=true for advanced", () => {
    render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
    const btn = screen.getByLabelText("Switch to Standard model");
    expect(btn.getAttribute("aria-pressed")).toBe("true");
  });
});

@@ -19,8 +19,16 @@ describe("formatResetTime", () => {
  });

  it("returns formatted date when over 24 hours away", () => {
    const result = formatResetTime("2025-06-17T00:00:00Z", now);
    expect(result).toMatch(/Tue/);
    const resetsAt = "2025-06-17T00:00:00Z";
    const result = formatResetTime(resetsAt, now);
    const expected = new Date(resetsAt).toLocaleString(undefined, {
      weekday: "short",
      hour: "numeric",
      minute: "2-digit",
      timeZoneName: "short",
    });

    expect(result).toBe(expected);
  });

  it("accepts a Date object for resetsAt", () => {

@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
import { getWebSocketToken } from "@/lib/supabase/actions";
import type { UIMessage } from "ai";

import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";

export const ORIGINAL_TITLE = "AutoGPT";

/**

@@ -50,6 +52,24 @@ export function parseSessionIDs(raw: string | null | undefined): Set<string> {
  }
}

/**
 * Resolve the actual dry_run value for a session from the raw API response.
 * Returns true only when the session response is a 200 with metadata.dry_run === true.
 * Returns false for missing/non-200 responses so callers never show a stale
 * preference value when the real session state is unknown.
 */
export function resolveSessionDryRun(queryData: unknown): boolean {
  if (
    queryData == null ||
    typeof queryData !== "object" ||
    !("status" in queryData) ||
    (queryData as { status: unknown }).status !== 200
  )
    return false;
  const d = queryData as { data?: { metadata?: { dry_run?: unknown } } };
  return d.data?.metadata?.dry_run === true;
}
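A quick illustration of the guard's behavior. The payload shapes below are hypothetical minimal objects matching the doc comment above, not real API responses:

```ts
// Only a 200 response with metadata.dry_run === true is trusted.
resolveSessionDryRun({ status: 200, data: { metadata: { dry_run: true } } }); // true
resolveSessionDryRun({ status: 200, data: { metadata: {} } }); // false
resolveSessionDryRun({ status: 404, data: { metadata: { dry_run: true } } }); // false: non-200 is never trusted
resolveSessionDryRun(undefined); // false: missing response
```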

/**
 * Check whether a refetchSession result indicates the backend still has an
 * active SSE stream for this session.

@@ -154,7 +174,18 @@ export function shouldSuppressDuplicateSend(
}

/**
 * Deduplicate messages by ID and by content fingerprint.
 * Fire-and-forget: tell the backend to release XREAD listeners for a session.
 *
 * Called on session switch so the backend doesn't wait for its 5-10 s timeout
 * before cleaning up. Failures are silently ignored — the backend will
 * eventually clean up on its own.
 */
export function disconnectSessionStream(sessionId: string): void {
  deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}

/**
 * Deduplicate messages by ID and by consecutive content fingerprint.
 *
 * ID dedup catches exact duplicates within the same source.
 * Content dedup uses a composite key of `role + preceding-user-message-id +

@@ -99,6 +99,50 @@ describe("artifactPanel store actions", () => {
    expect(s.history).toEqual([]);
  });

  it("openArtifact does not resurrect a previously closed artifact into history", () => {
    const a = makeArtifact("a");
    const b = makeArtifact("b");
    useCopilotUIStore.getState().openArtifact(a);
    useCopilotUIStore.getState().closeArtifactPanel();
    useCopilotUIStore.getState().openArtifact(b);

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(true);
    expect(s.activeArtifact?.id).toBe("b");
    expect(s.history).toEqual([]);
  });

  it("openArtifact ignores non-previewable artifacts", () => {
    const binary = {
      ...makeArtifact("bin", "artifact.bin"),
      mimeType: "application/octet-stream",
    };

    useCopilotUIStore.getState().openArtifact(binary);

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(false);
    expect(s.activeArtifact).toBeNull();
    expect(s.history).toEqual([]);
  });

  it("resetArtifactPanel clears active artifact and history", () => {
    const a = makeArtifact("a");
    const b = makeArtifact("b");
    useCopilotUIStore.getState().openArtifact(a);
    useCopilotUIStore.getState().openArtifact(b);
    useCopilotUIStore.getState().maximizeArtifactPanel();

    useCopilotUIStore.getState().resetArtifactPanel();

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(false);
    expect(s.isMinimized).toBe(false);
    expect(s.isMaximized).toBe(false);
    expect(s.activeArtifact).toBeNull();
    expect(s.history).toEqual([]);
  });

  it("minimize/restore toggles isMinimized without touching activeArtifact", () => {
    const a = makeArtifact("a");
    useCopilotUIStore.getState().openArtifact(a);

@@ -138,4 +182,35 @@ describe("artifactPanel store actions", () => {
    expect(s.width).toBe(720);
    expect(s.isMaximized).toBe(false);
  });

  it("history is capped at 25 entries (MAX_HISTORY)", () => {
    // Open 27 artifacts sequentially (A0..A26). History should never exceed 25.
    for (let i = 0; i < 27; i++) {
      useCopilotUIStore.getState().openArtifact(makeArtifact(`a${i}`));
    }
    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.activeArtifact?.id).toBe("a26");
    expect(s.history.length).toBe(25);
    // The oldest entry (a0) should have been dropped; a1 is the earliest surviving.
    expect(s.history[0].id).toBe("a1");
    expect(s.history[24].id).toBe("a25");
  });

  it("clearCopilotLocalData resets artifact panel to default", () => {
    const a = makeArtifact("a");
    const b = makeArtifact("b");
    useCopilotUIStore.getState().openArtifact(a);
    useCopilotUIStore.getState().openArtifact(b);
    useCopilotUIStore.getState().maximizeArtifactPanel();

    useCopilotUIStore.getState().clearCopilotLocalData();

    const s = useCopilotUIStore.getState().artifactPanel;
    expect(s.isOpen).toBe(false);
    expect(s.isMinimized).toBe(false);
    expect(s.isMaximized).toBe(false);
    expect(s.activeArtifact).toBeNull();
    expect(s.history).toEqual([]);
    expect(s.width).toBe(600); // DEFAULT_PANEL_WIDTH
  });
});
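The cap asserted in these tests comes from the `slice(-MAX_HISTORY)` in the store's `openArtifact` action, diffed below. A minimal standalone model of that bounded back stack (a sketch for illustration; the real logic also handles returning to the top of the stack and non-previewable refs):

```ts
// Standalone model of the bounded history stack (MAX_HISTORY = 25 in the store).
const MAX_HISTORY = 25;

function pushHistory<T>(history: T[], entry: T): T[] {
  // Appending then slicing from the end keeps only the newest 25 entries
  // and silently drops the oldest, which is why a0 is gone after 27 opens.
  return [...history, entry].slice(-MAX_HISTORY);
}
```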
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { create } from "zustand";
|
||||
import { clearContentCache } from "./components/ArtifactPanel/components/useArtifactContent";
|
||||
import { classifyArtifact } from "./components/ArtifactPanel/helpers";
|
||||
import { ORIGINAL_TITLE, parseSessionIDs } from "./helpers";
|
||||
|
||||
export interface DeleteTarget {
|
||||
@@ -52,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
|
||||
/** Autopilot response mode. */
|
||||
export type CopilotMode = "extended_thinking" | "fast";
|
||||
|
||||
/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
|
||||
export type CopilotLlmModel = "standard" | "advanced";
|
||||
|
||||
const isClient = typeof window !== "undefined";
|
||||
|
||||
function getPersistedWidth(): number {
|
||||
@@ -92,6 +96,10 @@ function persistCompletedSessions(ids: Set<string>) {
|
||||
}
|
||||
}
|
||||
|
||||
function isPreviewableArtifact(ref: ArtifactRef): boolean {
|
||||
return classifyArtifact(ref.mimeType, ref.title, ref.sizeBytes).openable;
|
||||
}
|
||||
|
||||
interface CopilotUIState {
|
||||
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
|
||||
initialPrompt: string | null;
|
||||
@@ -121,6 +129,7 @@ interface CopilotUIState {
|
||||
artifactPanel: ArtifactPanelState;
|
||||
openArtifact: (ref: ArtifactRef) => void;
|
||||
closeArtifactPanel: () => void;
|
||||
resetArtifactPanel: () => void;
|
||||
minimizeArtifactPanel: () => void;
|
||||
maximizeArtifactPanel: () => void;
|
||||
restoreArtifactPanel: () => void;
|
||||
@@ -128,8 +137,12 @@ interface CopilotUIState {
|
||||
goBackArtifact: () => void;
|
||||
|
||||
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
|
||||
copilotMode: CopilotMode;
|
||||
setCopilotMode: (mode: CopilotMode) => void;
|
||||
copilotChatMode: CopilotMode;
|
||||
setCopilotChatMode: (mode: CopilotMode) => void;
|
||||
|
||||
/** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
|
||||
copilotLlmModel: CopilotLlmModel;
|
||||
setCopilotLlmModel: (model: CopilotLlmModel) => void;
|
||||
|
||||
/** Developer dry-run mode: sessions created with dry_run=true. */
|
||||
isDryRun: boolean;
|
||||
@@ -203,14 +216,20 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
},
|
||||
openArtifact: (ref) =>
|
||||
set((state) => {
|
||||
if (!isPreviewableArtifact(ref)) return state;
|
||||
|
||||
const { activeArtifact, history: prevHistory } = state.artifactPanel;
|
||||
const topOfHistory = prevHistory[prevHistory.length - 1];
|
||||
const isReturningToTop = topOfHistory?.id === ref.id;
|
||||
const shouldPushHistory =
|
||||
state.artifactPanel.isOpen &&
|
||||
activeArtifact != null &&
|
||||
activeArtifact.id !== ref.id;
|
||||
const MAX_HISTORY = 25;
|
||||
const history = isReturningToTop
|
||||
? prevHistory.slice(0, -1)
|
||||
: activeArtifact && activeArtifact.id !== ref.id
|
||||
? [...prevHistory, activeArtifact].slice(-MAX_HISTORY)
|
||||
: shouldPushHistory
|
||||
? [...prevHistory, activeArtifact!].slice(-MAX_HISTORY)
|
||||
: prevHistory;
|
||||
return {
|
||||
artifactPanel: {
|
||||
@@ -231,6 +250,17 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
history: [],
|
||||
},
|
||||
})),
|
||||
resetArtifactPanel: () =>
|
||||
set((state) => ({
|
||||
artifactPanel: {
|
||||
...state.artifactPanel,
|
||||
isOpen: false,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
})),
|
||||
minimizeArtifactPanel: () =>
|
||||
set((state) => ({
|
||||
artifactPanel: { ...state.artifactPanel, isMinimized: true },
|
||||
@@ -275,9 +305,22 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
};
|
||||
}),
|
||||
|
||||
copilotMode: "extended_thinking",
|
||||
setCopilotMode: (mode) => {
|
||||
set({ copilotMode: mode });
|
||||
copilotChatMode: (() => {
|
||||
const saved = isClient ? storage.get(Key.COPILOT_MODE) : null;
|
||||
return saved === "fast" ? "fast" : "extended_thinking";
|
||||
})(),
|
||||
setCopilotChatMode: (mode) => {
|
||||
storage.set(Key.COPILOT_MODE, mode);
|
||||
set({ copilotChatMode: mode });
|
||||
},
|
||||
|
||||
copilotLlmModel: (() => {
|
||||
const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null;
|
||||
return saved === "advanced" ? "advanced" : "standard";
|
||||
})(),
|
||||
setCopilotLlmModel: (model) => {
|
||||
storage.set(Key.COPILOT_MODEL, model);
|
||||
set({ copilotLlmModel: model });
|
||||
},
|
||||
|
||||
isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
|
||||
@@ -299,6 +342,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
|
||||
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
|
||||
storage.clean(Key.COPILOT_DRY_RUN);
|
||||
storage.clean(Key.COPILOT_MODE);
|
||||
storage.clean(Key.COPILOT_MODEL);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
@@ -311,7 +356,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
copilotMode: "extended_thinking",
|
||||
copilotChatMode: "extended_thinking",
|
||||
copilotLlmModel: "standard",
|
||||
isDryRun: false,
|
||||
});
|
||||
if (isClient) {
|
||||
|
||||
@@ -1,15 +1,13 @@
"use client";

import React, { useState } from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { Button } from "@/components/atoms/Button/Button";
import type { BlockOutputResponse } from "@/app/api/__generated__/models/blockOutputResponse";
import {
globalRegistry,
OutputItem,
} from "@/components/contextual/OutputRenderers";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
import { resolveForRenderer } from "@/app/(platform)/copilot/tools/ViewAgentOutput/ViewAgentOutput";
import {
ContentBadge,
ContentCard,
@@ -24,28 +22,6 @@ interface Props {

const COLLAPSED_LIMIT = 3;

function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
if (!isWorkspaceURI(value)) return { value };

const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };

const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;

const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}

return { value: url, metadata };
}

function RenderOutputValue({ value }: { value: unknown }) {
const resolved = resolveForRenderer(value);
const renderer = globalRegistry.getRenderer(
@@ -63,16 +39,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}

// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}

return null;
}
@@ -2,7 +2,6 @@

import type { ToolUIPart } from "ai";
import React from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
globalRegistry,
OutputItem,
@@ -47,7 +46,7 @@ interface Props {
part: ViewAgentOutputToolPart;
}

function resolveForRenderer(value: unknown): {
export function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
@@ -56,17 +55,17 @@ function resolveForRenderer(value: unknown): {
const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };

const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;

// Pass workspace URIs through to the registry unchanged.
// WorkspaceFileRenderer (priority 50) matches workspace:// URIs and
// handles URL building, loading skeletons, and error states internally.
// Previously this converted to a proxy URL which bypassed
// WorkspaceFileRenderer, causing ImageRenderer (bare <img>) to match.
const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}

return { value: url, metadata };
return { value, metadata };
}

function RenderOutputValue({ value }: { value: unknown }) {
@@ -86,16 +85,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}

// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}

return null;
}
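The behavioral change above hinges on the `workspace://<fileID>#<mimeType>` convention visible in the tests below: the file ID is the authority component and the MIME type rides in the URI fragment. The real helpers live in `@/lib/workspace-uri` and are not shown in this diff; the sketch below is a plausible re-derivation, with the parsed shape inferred from the `parsed.fileID` / `parsed.mimeType` usage above.

```ts
// Hypothetical re-derivation of the workspace-uri helpers, inferred from how
// the diff uses them. The real implementations live in "@/lib/workspace-uri".
interface ParsedWorkspaceURI {
  fileID: string;
  mimeType?: string;
}

function isWorkspaceURILike(value: unknown): value is string {
  return typeof value === "string" && value.startsWith("workspace://");
}

function parseWorkspaceURILike(uri: string): ParsedWorkspaceURI | null {
  if (!isWorkspaceURILike(uri)) return null;
  const rest = uri.slice("workspace://".length);
  // Split off the optional "#<mimeType>" fragment.
  const [fileID, mimeType] = rest.split("#", 2);
  if (!fileID) return null;
  return mimeType ? { fileID, mimeType } : { fileID };
}

// parseWorkspaceURILike("workspace://abc123#image/png")
//   -> { fileID: "abc123", mimeType: "image/png" }
```

Keeping the URI intact means the registry's priority ordering, not this call site, decides that WorkspaceFileRenderer (priority 50) wins over ImageRenderer, which is exactly what the new tests pin down.
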
@@ -0,0 +1,52 @@
import { describe, expect, it } from "vitest";
import { resolveForRenderer } from "../ViewAgentOutput";
import { globalRegistry } from "@/components/contextual/OutputRenderers";

describe("resolveForRenderer", () => {
it("preserves workspace image URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://abc123#image/png");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("image/png");
});

it("preserves workspace video URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://vid456#video/mp4");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("video/mp4");
});

it("passes non-workspace values through unchanged", () => {
const result = resolveForRenderer("just a string");
expect(result.value).toBe("just a string");
expect(result.metadata).toBeUndefined();
});

it("passes non-string values through unchanged", () => {
const obj = { foo: "bar" };
const result = resolveForRenderer(obj);
expect(result.value).toBe(obj);
expect(result.metadata).toBeUndefined();
});

it("workspace image URIs match WorkspaceFileRenderer with loading/error states", () => {
// WorkspaceFileRenderer (priority 50) should handle workspace:// URIs
// since resolveForRenderer no longer pre-converts them to proxy URLs.
const resolved = resolveForRenderer("workspace://abc123#image/png");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});

it("workspace video URIs match WorkspaceFileRenderer", () => {
const resolved = resolveForRenderer("workspace://vid456#video/mp4");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});
});
@@ -10,6 +10,7 @@ import { useQueryClient } from "@tanstack/react-query";
import { parseAsString, useQueryState } from "nuqs";
import { useEffect, useMemo, useRef } from "react";
import { convertChatSessionMessagesToUiMessages } from "./helpers/convertChatSessionToUiMessages";
import { resolveSessionDryRun } from "./helpers";

interface UseChatSessionOptions {
dryRun?: boolean;
@@ -163,6 +164,18 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
? ((sessionQuery.data.data.messages ?? []) as unknown[])
: [];

// The actual dry_run value stored in the session's metadata, read directly
// from the API response. This reflects what the session was ACTUALLY created
// with — not the user's current UI preference (isDryRun store).
//
// Design intent: the global isDryRun store is only used when creating NEW
// sessions. Once a session exists, its dry_run flag is immutable and should
// be read from here rather than from the store, which may have changed.
const sessionDryRun = useMemo(
() => resolveSessionDryRun(sessionQuery.data),
[sessionQuery.data],
);

return {
sessionId,
setSessionId,
@@ -177,5 +190,6 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) {
createSession,
isCreatingSession,
refetchSession: sessionQuery.refetch,
sessionDryRun,
};
}
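`resolveSessionDryRun` itself is imported from `./helpers` but its body is not part of this diff. Given the comment above, one plausible shape is a null-safe read of the session's metadata, sketched here under that assumption.

```ts
// Assumed shape of resolveSessionDryRun (the real helper lives in ./helpers
// and is not shown in this diff): read dry_run from the session response and
// treat anything missing or malformed as false.
interface SessionResponseLike {
  data?: {
    metadata?: { dry_run?: unknown } | null;
  } | null;
}

function resolveSessionDryRunSketch(
  response: SessionResponseLike | undefined,
): boolean {
  // Strict equality so non-boolean junk in metadata never reads as true.
  return response?.data?.metadata?.dry_run === true;
}
```
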
@@ -42,7 +42,8 @@ export function useCopilotPage() {
setSessionToDelete,
isDrawerOpen,
setDrawerOpen,
copilotMode,
copilotChatMode,
copilotLlmModel,
isDryRun,
} = useCopilotUIStore();

@@ -60,6 +61,7 @@ export function useCopilotPage() {
createSession,
isCreatingSession,
refetchSession,
sessionDryRun,
} = useChatSession({ dryRun: isDryRun });

const {
@@ -78,7 +80,8 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
copilotMode: isModeToggleEnabled ? copilotChatMode : undefined,
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});

const { olderMessages, hasMore, isLoadingMore, loadMore } =
@@ -416,6 +419,11 @@ export function useCopilotPage() {
rateLimitMessage,
dismissRateLimit,
// Dry run dev toggle
// isDryRun = global preference for NEW sessions (from localStorage).
// sessionDryRun = actual dry_run value of the CURRENT session (from API).
// Use isDryRun to configure future sessions; use sessionDryRun to display
// the current session's simulation state (banner, indicators).
isDryRun,
sessionDryRun,
};
}
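The comment in the returned object draws the line consumers should respect: `isDryRun` feeds session creation, `sessionDryRun` feeds display. A hypothetical consumer, just to make the split concrete; the component, its import path, and the markup are invented for this sketch, and only the two flags and `useCopilotPage` come from the diff.

```tsx
import React from "react";
// Import path assumed for illustration; the diff does not show the hook's file.
import { useCopilotPage } from "./useCopilotPage";

// Illustrates which flag drives what: sessionDryRun renders the banner for
// the CURRENT session, isDryRun reflects the preference for NEW sessions.
function CopilotDryRunIndicator() {
  const { isDryRun, sessionDryRun } = useCopilotPage();

  return (
    <>
      {/* The current session's immutable dry_run flag, read from the API. */}
      {sessionDryRun && <div role="status">Simulation mode: no side effects</div>}
      {/* The stored preference that future sessions will be created with. */}
      <label>
        <input type="checkbox" checked={isDryRun} readOnly /> Create new
        sessions as dry runs
      </label>
    </>
  );
}
```
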