Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into fix/openrouter-null-cache-tokens

majdyz
2026-04-15 14:49:00 +07:00
149 changed files with 13948 additions and 4640 deletions

View File

@@ -48,14 +48,15 @@ git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
For each changed file, determine:
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
**Priority order:**
1. Pages with new/changed data fetching or user interactions
2. Components with complex internal logic (modals, forms, wizards)
3. Hooks with non-trivial business logic
3. Shared hooks with standalone business logic when UI-level coverage is impractical
4. Pure helper functions
Skip: styling-only changes, type-only changes, config changes.
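The triage above is mechanical enough to sketch. A minimal illustration in Python (purely hypothetical — the guide prescribes no such script, and the patterns are loose approximations of the globs above):

```python
import re

# Hypothetical triage of changed file paths into the categories above.
# Patterns are rough stand-ins for the guide's globs, not its tooling.
RULES: list[tuple[str, re.Pattern[str]]] = [
    ("page", re.compile(r"/page\.tsx$")),                # primary test target
    ("hook", re.compile(r"/use[A-Z]\w*\.tsx?$")),        # test via consuming page
    ("component", re.compile(r"/components/.+\.tsx$")),  # test via parent page
    ("helper", re.compile(r"/(helpers|utils)\.ts$")),    # unit test if pure
]
SKIP = re.compile(r"\.(css|d\.ts)$|config")  # styling/type/config-only changes

def classify(path: str) -> str | None:
    if SKIP.search(path):
        return None
    for kind, pattern in RULES:
        if pattern.search(path):
            return kind
    return None

assert classify("src/app/library/page.tsx") == "page"
assert classify("src/hooks/useAgents.ts") == "hook"
assert classify("src/app/globals.css") is None
```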
@@ -163,6 +164,7 @@ describe("LibraryPage", () => {
- Use `waitFor` when asserting side effects or state changes after interactions
- Import `fireEvent` or `userEvent` from the test-utils for interactions
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
- Keep tests focused: one behavior per test
- Use descriptive test names that read like sentences
@@ -190,9 +192,7 @@ import { http, HttpResponse } from "msw";
server.use(
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
return HttpResponse.json({
agents: [
{ id: "1", name: "Test Agent", description: "A test agent" },
],
agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
});
}),
@@ -211,6 +211,7 @@ pnpm test:unit --reporter=verbose
```
If tests fail:
1. Read the error output carefully
2. Fix the test (not the source code, unless there is a genuine bug)
3. Re-run until all pass

View File

@@ -160,6 +160,7 @@ jobs:
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -288,6 +289,14 @@ jobs:
cache: "pnpm"
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
- name: Set up tests - Cache Playwright browsers
uses: actions/cache@v5
with:
path: ~/.cache/ms-playwright
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
restore-keys: |
playwright-${{ runner.os }}-
- name: Copy source maps from Docker for E2E coverage
run: |
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
@@ -299,8 +308,8 @@ jobs:
- name: Set up tests - Install browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Run Playwright tests
run: pnpm test:no-build
- name: Run Playwright E2E suite
run: pnpm test:e2e:no-build
continue-on-error: false
- name: Upload E2E coverage to Codecov

.gitignore vendored
View File

@@ -194,3 +194,4 @@ test.db
.next
# Implementation plans (generated by AI agents)
plans/
.claude/worktrees/

View File

@@ -0,0 +1,166 @@
{
"id": "858e2226-e047-4d19-a832-3be4a134d155",
"version": 2,
"is_active": true,
"name": "Calculator agent",
"description": "",
"instructions": null,
"recommended_schedule_cron": null,
"forked_from_id": null,
"forked_from_version": null,
"user_id": "",
"created_at": "2026-04-13T03:45:11.241Z",
"nodes": [
{
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
"input_default": {
"name": "Input",
"secret": false,
"advanced": false
},
"metadata": {
"position": {
"x": -188.2244873046875,
"y": 95
}
},
"input_links": [],
"output_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
"input_default": {
"name": "Output",
"secret": false,
"advanced": false,
"escape_html": false
},
"metadata": {
"position": {
"x": 825.198974609375,
"y": 123.75
}
},
"input_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"output_links": [],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
},
{
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
"input_default": {
"b": 34,
"operation": "Add",
"round_result": false
},
"metadata": {
"position": {
"x": 323.0255126953125,
"y": 121.25
}
},
"input_links": [
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"output_links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
}
],
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
"graph_version": 2,
"webhook_id": null
}
],
"links": [
{
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
"source_name": "result",
"sink_name": "value",
"is_static": false
},
{
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
"source_name": "result",
"sink_name": "a",
"is_static": true
}
],
"sub_graphs": [],
"input_schema": {
"type": "object",
"properties": {
"Input": {
"advanced": false,
"secret": false,
"title": "Input"
}
},
"required": [
"Input"
]
},
"output_schema": {
"type": "object",
"properties": {
"Output": {
"advanced": false,
"secret": false,
"title": "Output"
}
},
"required": [
"Output"
]
},
"has_external_trigger": false,
"has_human_in_the_loop": false,
"has_sensitive_action": false,
"trigger_setup_info": null,
"credentials_input_schema": {
"type": "object",
"properties": {},
"required": []
}
}

View File

@@ -25,6 +25,7 @@ from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
is_credentials_field_name,
)
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, NodeExecutionStats
from backend.data.model import ContributorDetails
from ..data.graph import Link
@@ -420,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Return extra runtime cost to charge after this block run completes.
Called by the executor after a block finishes with COMPLETED status.
The return value is the number of additional base-cost credits to
charge beyond the single credit already collected by charge_usage
at the start of execution. Defaults to 0 (no extra charges).
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
calls within one run and should be billed per call.
"""
return 0
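The executor side of this contract is not part of this diff; a minimal sketch of how a caller might settle the extra charge after a COMPLETED run (all names here — `settle_extra_runtime_cost`, `base_cost`, `charge` — are hypothetical):

```python
from typing import Awaitable, Callable

async def settle_extra_runtime_cost(
    block,            # anything implementing extra_runtime_cost()
    execution_stats,  # the NodeExecutionStats for the finished run
    base_cost: int,   # credits per charge unit for this block
    charge: Callable[[int], Awaitable[None]],
) -> None:
    # charge_usage() already collected one base-cost credit up front; only
    # the extra credits reported after a COMPLETED run are billed here.
    extra_credits = block.extra_runtime_cost(execution_stats)
    if extra_credits > 0:
        await charge(extra_credits * base_cost)
```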
def __init__(
self,
id: str = "",
@@ -455,8 +469,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
disabled: If the block is disabled, it will not be available for execution.
static_output: Whether the output links of the block are static by default.
"""
from backend.data.model import NodeExecutionStats
self.id = id
self.input_schema = input_schema
self.output_schema = output_schema
@@ -474,7 +486,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.is_sensitive_action = is_sensitive_action
# Read from ClassVar set by initialize_blocks()
self.optimized_description: str | None = type(self)._optimized_description
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -554,7 +566,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
self.execution_stats += stats
return self.execution_stats

View File

@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import InsufficientBalanceError
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
from backend.util.security import SENSITIVE_FIELD_NAMES
from backend.util.tool_call_loop import (
@@ -364,10 +365,31 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
class OrchestratorBlock(Block):
"""A block that uses a language model to orchestrate tool calls.
Supports both single-shot and iterative agent mode execution.
**InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
(IBE) must always re-raise through every ``except`` block in this class.
Swallowing IBE would let the agent loop continue with unpaid work. Every
exception handler that catches ``Exception`` includes an explicit IBE
re-raise carve-out for this reason.
"""
A block that uses a language model to orchestrate tool calls, supporting both
single-shot and iterative agent mode execution.
"""
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
"""Charge one extra runtime cost per LLM call beyond the first.
In agent mode each iteration makes one LLM call. The first is already
covered by charge_usage(); this returns the number of additional
credits so the executor can bill the remaining calls post-completion.
SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
the SDK manages its own conversation loop and only exposes aggregate
usage. We hardcode llm_call_count=1 there (the SDK does not report a
per-turn call count), so this method always returns 0 for SDK-mode
executions. Per-iteration billing does not apply to SDK mode.
"""
return max(0, execution_stats.llm_call_count - 1)
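A quick worked example of that formula, assuming `llm_call_count` reflects per-iteration calls in agent mode and is hardcoded to 1 in SDK mode:

```python
# max(0, llm_call_count - 1): the first call is covered by charge_usage().
assert max(0, 3 - 1) == 2  # agent mode, 3 iterations → 2 extra credits
assert max(0, 1 - 1) == 0  # SDK mode (count hardcoded to 1) → no extras
assert max(0, 0 - 1) == 0  # defensive clamp: never negative
```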
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1077,7 +1099,10 @@ class OrchestratorBlock(Block):
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
if node_exec_result is None:
raise RuntimeError(
f"upsert_execution_input returned None for node {sink_node_id}"
)
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
@@ -1112,15 +1137,86 @@ class OrchestratorBlock(Block):
task=node_exec_future,
)
# Execute the node directly since we're in the Orchestrator context
node_exec_future.set_result(
await execution_processor.on_node_execution(
# Execute the node directly since we're in the Orchestrator context.
# Wrap in try/except so the future is always resolved, even on
# error — an unresolved Future would block anything awaiting it.
#
# on_node_execution is decorated with @async_error_logged(swallow=True),
# which catches BaseException and returns None rather than raising.
# Treat a None return as a failure: set_exception so the future
# carries an error state rather than a None result, and return an
# error response so the LLM knows the tool failed.
try:
tool_node_stats = await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
if tool_node_stats is None:
nil_err = RuntimeError(
f"on_node_execution returned None for node {sink_node_id} "
"(error was swallowed by @async_error_logged)"
)
node_exec_future.set_exception(nil_err)
resp = _create_tool_response(
tool_call.id,
"Tool execution returned no result",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
node_exec_future.set_result(tool_node_stats)
except Exception as exec_err:
node_exec_future.set_exception(exec_err)
raise
# Charge user credits AFTER successful tool execution. Tools
# spawned by the orchestrator bypass the main execution queue
# (where _charge_usage is called), so we must charge here to
# avoid free tool execution. Charging post-completion (vs.
# pre-execution) avoids billing users for failed tool calls.
# Skipped for dry runs.
#
# `error is None` intentionally excludes both Exception and
# BaseException subclasses (e.g. CancelledError) so cancelled
# or terminated tool runs are not billed.
#
# Billing errors (including non-balance exceptions) are kept
# in a separate try/except so they are never silently swallowed
# by the generic tool-error handler below.
if (
not execution_params.execution_context.dry_run
and tool_node_stats.error is None
):
try:
tool_cost, _ = await execution_processor.charge_node_usage(
node_exec_entry,
)
except InsufficientBalanceError:
# IBE must propagate — see OrchestratorBlock class docstring.
# Log the billing failure here so the discarded tool result
# is traceable before the loop aborts.
logger.warning(
"Insufficient balance charging for tool node %s after "
"successful execution; agent loop will be aborted",
sink_node_id,
)
raise
except Exception:
# Non-billing charge failures (DB outage, network, etc.)
# must NOT propagate to the outer except handler because
# the tool itself succeeded. Re-raising would mark the
# tool as failed (_is_error=True), causing the LLM to
# retry side-effectful operations. Log and continue.
logger.exception(
"Unexpected error charging for tool node %s; "
"tool execution was successful",
sink_node_id,
)
tool_cost = 0
if tool_cost > 0:
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
@@ -1133,18 +1229,26 @@ class OrchestratorBlock(Block):
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(
resp = _create_tool_response(
tool_call.id, tool_response_content, responses_api=responses_api
)
resp["_is_error"] = False
return resp
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.warning("Tool execution with manager failed: %s", e)
# Return error response
return _create_tool_response(
logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
# Return a generic error to the LLM — internal exception messages
# may contain server paths, DB details, or infrastructure info.
resp = _create_tool_response(
tool_call.id,
f"Tool execution failed: {e}",
"Tool execution failed due to an internal error",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
async def _agent_mode_llm_caller(
self,
@@ -1244,13 +1348,16 @@ class OrchestratorBlock(Block):
content = str(raw_content)
else:
content = "Tool executed successfully"
tool_failed = content.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return ToolCallResult(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
content=content,
is_error=tool_failed,
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("Tool execution failed: %s", e)
return ToolCallResult(
@@ -1370,9 +1477,13 @@ class OrchestratorBlock(Block):
"arguments": tc.arguments,
},
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
# Catch all errors (validation, network, API) so that the block
# surfaces them as user-visible output instead of crashing.
# Catch all OTHER errors (validation, network, API) so that
# the block surfaces them as user-visible output instead of
# crashing.
yield "error", str(e)
return
@@ -1450,11 +1561,14 @@ class OrchestratorBlock(Block):
text = content
else:
text = json.dumps(content)
tool_failed = text.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return {
"content": [{"type": "text", "text": text}],
"isError": tool_failed,
}
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("SDK tool execution failed: %s", e)
return {
@@ -1733,11 +1847,15 @@ class OrchestratorBlock(Block):
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
pass
except InsufficientBalanceError:
# IBE must propagate — see class docstring. The `finally`
# block below still runs and records partial token usage.
raise
except Exception as e:
# Surface SDK errors as user-visible output instead of crashing,
# consistent with _execute_tools_agent_mode error handling.
# Don't return yet — fall through to merge_stats below so
# partial token usage is always recorded.
# Surface OTHER SDK errors as user-visible output instead
# of crashing, consistent with _execute_tools_agent_mode
# error handling. Don't return yet — fall through to
# merge_stats below so partial token usage is always recorded.
sdk_error = e
finally:
# Always record usage stats, even on error. The SDK may have

View File

@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Returns (cost, remaining_balance). Must be AsyncMock because it is
# an async method and is directly awaited in _execute_single_tool_with_manager.
# Use a non-zero cost so the merge_stats branch is exercised.
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
# Verify charge_node_usage was actually called for the successful
# tool execution — this guards against regressions where the
# post-execution tool charging is accidentally removed.
assert mock_execution_processor.charge_node_usage.call_count == 1
@pytest.mark.asyncio
async def test_orchestrator_traditional_mode_default():

View File

@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would
# return a non-awaitable tuple and TypeError out, then be
# silently swallowed by the orchestrator's catch-all.
mock_execution_processor.charge_node_usage = AsyncMock(
return_value=(0, 0)
)
async for output_name, output_value in block.run(
input_data,

View File

@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
ep.execution_stats_lock = threading.Lock()
ns = MagicMock(error=None)
ep.on_node_execution = AsyncMock(return_value=ns)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would return a
# non-awaitable tuple and TypeError out, then be silently swallowed by
# the orchestrator's catch-all.
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
block, "_create_tool_node_signatures", return_value=tool_sigs

View File

@@ -197,6 +197,15 @@ class ChatConfig(BaseSettings):
description="Maximum number of retries for transient API errors "
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
)
claude_agent_cross_user_prompt_cache: bool = Field(
default=True,
description="Enable cross-user prompt caching via SystemPromptPreset. "
"The Claude Code default prompt becomes a cacheable prefix shared "
"across all users, and our custom prompt is appended after it. "
"Dynamic sections (working dir, git status, auto-memory) are excluded "
"from the prefix. Set to False to fall back to passing the system "
"prompt as a raw string.",
)
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "

View File

@@ -0,0 +1,555 @@
"""Tests for context fallback paths introduced in fix/copilot-transcript-resume-gate.
Scenario table
==============
| # | use_resume | transcript_msg_count | gap | target_tokens | Expected output |
|---|------------|----------------------|---------|---------------|--------------------------------------------|
| A | True | covers all | empty | None | bare message (--resume has full context) |
| B | True | stale | 2 msgs | None | gap context prepended |
| C | True | stale | 2 msgs | 50_000 | gap compressed to budget, prepended |
| D | False | 0 | N/A | None | full session compressed, prepended |
| E | False | 0 | N/A | 50_000 | full session compressed to budget |
| F | False | 2 (partial) | 2 msgs | None | full session compressed (not just gap; |
| | | | | | CLI has zero context without --resume) |
| G | False | 2 (partial) | 2 msgs | 50_000 | full session compressed to budget |
| H | False | covers all | empty | None | full session compressed |
| | | | | | (NOT bare message — the bug that was fixed)|
| I | False | covers all | empty | 50_000 | full session compressed to tight budget |
| J | False | 2 (partial) | N/A | None | exactly ONE compression call (full prior) |
Compression unit tests
=======================
| # | Input | target_tokens | Expected |
|---|----------------------|---------------|-----------------------------------------------|
| K | [] | None | ([], False) — empty guard |
| L | [1 msg] | None | ([msg], False) — single-msg guard |
| M | [2+ msgs] | None | target_tokens=None forwarded to _run_compression |
| N | [2+ msgs] | 30_000 | target_tokens=30_000 forwarded |
| O | [2+ msgs], run fails | None | returns originals, False |
"""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import _build_query_message, _compress_messages
from backend.util.prompt import CompressResult
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(messages: list[ChatMessage]) -> ChatSession:
now = datetime.now(UTC)
return ChatSession(
session_id="test-session",
user_id="user-1",
messages=messages,
title="test",
usage=[],
started_at=now,
updated_at=now,
)
def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
return [ChatMessage(role=r, content=c) for r, c in pairs]
def _passthrough_compress(target_tokens=None):
"""Return a mock that passes messages through and records its call args."""
calls: list[tuple[list, int | None]] = []
async def _mock(msgs, tok=None):
calls.append((msgs, tok))
return msgs, False
_mock.calls = calls # type: ignore[attr-defined]
return _mock
# ---------------------------------------------------------------------------
# _build_query_message — scenarios A–J
# ---------------------------------------------------------------------------
class TestBuildQueryMessageResume:
"""use_resume=True paths (--resume supplies history; only inject gap if stale)."""
@pytest.mark.asyncio
async def test_scenario_a_transcript_current_returns_bare_message(self):
"""Scenario A: --resume covers full context → no prefix injected."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
result, compacted = await _build_query_message(
"q2", session, use_resume=True, transcript_msg_count=2, session_id="s"
)
assert result == "q2"
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_b_stale_transcript_injects_gap(self, monkeypatch):
"""Scenario B: stale transcript → gap context prepended."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, compacted = await _build_query_message(
"q3", session, use_resume=True, transcript_msg_count=2, session_id="s"
)
assert "<conversation_history>" in result
assert "q2" in result
assert "a2" in result
assert "Now, the user says:\nq3" in result
# q1/a1 are covered by the transcript — must NOT appear in gap context
assert "q1" not in result
@pytest.mark.asyncio
async def test_scenario_c_stale_transcript_passes_target_tokens(self, monkeypatch):
"""Scenario C: target_tokens is forwarded to _compress_messages for the gap."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=True,
transcript_msg_count=2,
session_id="s",
target_tokens=50_000,
)
assert captured == [50_000]
class TestBuildQueryMessageNoResumeNoTranscript:
"""use_resume=False, transcript_msg_count=0 — full session compressed."""
@pytest.mark.asyncio
async def test_scenario_d_full_session_compressed(self, monkeypatch):
"""Scenario D: no resume, no transcript → compress all prior messages."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, compacted = await _build_query_message(
"q2", session, use_resume=False, transcript_msg_count=0, session_id="s"
)
assert "<conversation_history>" in result
assert "q1" in result
assert "a1" in result
assert "Now, the user says:\nq2" in result
@pytest.mark.asyncio
async def test_scenario_e_passes_target_tokens_to_compression(self, monkeypatch):
"""Scenario E: target_tokens forwarded to _compress_messages."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q2",
session,
use_resume=False,
transcript_msg_count=0,
session_id="s",
target_tokens=15_000,
)
assert captured == [15_000]
class TestBuildQueryMessageNoResumeWithTranscript:
"""use_resume=False, transcript_msg_count > 0 — gap or full-session fallback."""
@pytest.mark.asyncio
async def test_scenario_f_no_resume_always_injects_full_session(self, monkeypatch):
"""Scenario F: use_resume=False with transcript_msg_count > 0 still injects
the FULL prior session — not just the gap since the transcript end.
When there is no --resume the CLI starts with zero context, so injecting
only the post-transcript gap would silently drop all transcript-covered
history. The correct fix is to always compress the full session.
"""
session = _make_session(
_msgs(
("user", "q1"), # transcript_msg_count=2 covers these
("assistant", "a1"),
("user", "q2"), # post-transcript gap starts here
("assistant", "a2"),
("user", "q3"), # current message
)
)
compressed_msgs: list[list] = []
async def _mock_compress(msgs, target_tokens=None):
compressed_msgs.append(list(msgs))
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, _ = await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2, # transcript covers q1/a1 but no --resume
session_id="s",
)
assert "<conversation_history>" in result
# Full session must be injected — transcript-covered turns ARE included
assert "q1" in result
assert "a1" in result
assert "q2" in result
assert "a2" in result
assert "Now, the user says:\nq3" in result
# Compressed exactly once with all 4 prior messages
assert len(compressed_msgs) == 1
assert len(compressed_msgs[0]) == 4
@pytest.mark.asyncio
async def test_scenario_g_no_resume_passes_target_tokens(self, monkeypatch):
"""Scenario G: target_tokens forwarded when use_resume=False + transcript_msg_count > 0."""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
target_tokens=50_000,
)
assert captured == [50_000]
@pytest.mark.asyncio
async def test_scenario_h_no_resume_transcript_current_injects_full_session(
self, monkeypatch
):
"""Scenario H: the bug that was fixed.
Old code path: use_resume=False, transcript_msg_count covers all prior
messages → gap sub-path: gap = [] → ``return current_message, False``
→ model received ZERO context (bare message only).
New code path: use_resume=False always compresses the full prior session
regardless of transcript_msg_count — model always gets context.
"""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
result, _ = await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=4, # covers ALL prior → old code returned bare msg
session_id="s",
)
# NEW: must inject full session, NOT return bare message
assert result != "q3"
assert "<conversation_history>" in result
assert "q1" in result
assert "Now, the user says:\nq3" in result
@pytest.mark.asyncio
async def test_scenario_i_no_resume_target_tokens_forwarded_any_transcript_count(
self, monkeypatch
):
"""Scenario I: target_tokens forwarded even when transcript_msg_count covers all."""
session = _make_session(
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
)
captured: list[int | None] = []
async def _mock_compress(msgs, target_tokens=None):
captured.append(target_tokens)
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q2",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
target_tokens=15_000,
)
assert 15_000 in captured
@pytest.mark.asyncio
async def test_scenario_j_no_resume_single_compression_call(self, monkeypatch):
"""Scenario J: use_resume=False always makes exactly ONE compression call
(the full session), regardless of transcript coverage.
This verifies there is no two-step gap+fallback pattern for no-resume —
compression is called once with the full prior session.
"""
session = _make_session(
_msgs(
("user", "q1"),
("assistant", "a1"),
("user", "q2"),
("assistant", "a2"),
("user", "q3"),
)
)
call_count = 0
async def _mock_compress(msgs, target_tokens=None):
nonlocal call_count
call_count += 1
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages", _mock_compress
)
await _build_query_message(
"q3",
session,
use_resume=False,
transcript_msg_count=2,
session_id="s",
)
assert call_count == 1
# ---------------------------------------------------------------------------
# _compress_messages — unit tests K–O
# ---------------------------------------------------------------------------
class TestCompressMessages:
@pytest.mark.asyncio
async def test_scenario_k_empty_list_returns_empty(self):
"""Scenario K: empty input → short-circuit, no compression."""
result, compacted = await _compress_messages([])
assert result == []
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_l_single_message_returns_as_is(self):
"""Scenario L: single message → short-circuit (< 2 guard)."""
msg = ChatMessage(role="user", content="hello")
result, compacted = await _compress_messages([msg])
assert result == [msg]
assert compacted is False
@pytest.mark.asyncio
async def test_scenario_m_target_tokens_none_forwarded(self):
"""Scenario M: target_tokens=None forwarded to _run_compression."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
fake_result = CompressResult(
messages=[
{"role": "user", "content": "q"},
{"role": "assistant", "content": "a"},
],
token_count=10,
was_compacted=False,
original_token_count=10,
)
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
return_value=fake_result,
) as mock_run:
await _compress_messages(msgs, target_tokens=None)
mock_run.assert_awaited_once()
_, kwargs = mock_run.call_args
assert kwargs.get("target_tokens") is None
@pytest.mark.asyncio
async def test_scenario_n_explicit_target_tokens_forwarded(self):
"""Scenario N: explicit target_tokens forwarded to _run_compression."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
fake_result = CompressResult(
messages=[{"role": "user", "content": "summary"}],
token_count=5,
was_compacted=True,
original_token_count=50,
)
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
return_value=fake_result,
) as mock_run:
result, compacted = await _compress_messages(msgs, target_tokens=30_000)
mock_run.assert_awaited_once()
_, kwargs = mock_run.call_args
assert kwargs.get("target_tokens") == 30_000
assert compacted is True
@pytest.mark.asyncio
async def test_scenario_o_run_compression_exception_returns_originals(self):
"""Scenario O: _run_compression raises → return original messages, False."""
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
with patch(
"backend.copilot.sdk.service._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("compression timeout"),
):
result, compacted = await _compress_messages(msgs)
assert result == msgs
assert compacted is False
@pytest.mark.asyncio
async def test_compaction_messages_filtered_before_compression(self):
"""filter_compaction_messages is applied before _run_compression is called."""
# A compaction message is one with role=assistant and specific content pattern.
# We verify that only real messages reach _run_compression.
from backend.copilot.sdk.service import filter_compaction_messages
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="a"),
]
# filter_compaction_messages should not remove these plain messages
filtered = filter_compaction_messages(msgs)
assert len(filtered) == len(msgs)
# ---------------------------------------------------------------------------
# target_tokens threading — _RETRY_TARGET_TOKENS values match expectations
# ---------------------------------------------------------------------------
class TestRetryTargetTokens:
def test_first_retry_uses_first_slot(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[0] == 50_000
def test_second_retry_uses_second_slot(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[1] == 15_000
def test_second_slot_smaller_than_first(self):
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
assert _RETRY_TARGET_TOKENS[1] < _RETRY_TARGET_TOKENS[0]
# ---------------------------------------------------------------------------
# Single-message session edge cases
# ---------------------------------------------------------------------------
class TestSingleMessageSessions:
@pytest.mark.asyncio
async def test_no_resume_single_message_returns_bare(self):
"""First turn (1 message): no prior history to inject."""
session = _make_session([ChatMessage(role="user", content="hello")])
result, compacted = await _build_query_message(
"hello", session, use_resume=False, transcript_msg_count=0, session_id="s"
)
assert result == "hello"
assert compacted is False
@pytest.mark.asyncio
async def test_resume_single_message_returns_bare(self):
"""First turn with resume flag: transcript is empty so no gap."""
session = _make_session([ChatMessage(role="user", content="hello")])
result, compacted = await _build_query_message(
"hello", session, use_resume=True, transcript_msg_count=0, session_id="s"
)
assert result == "hello"
assert compacted is False

View File

@@ -6,6 +6,7 @@ import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
_BARE_MESSAGE_TOKEN_FLOOR,
_build_query_message,
_format_conversation_context,
)
@@ -130,6 +131,34 @@ async def test_build_query_resume_up_to_date():
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_resume_misaligned_watermark():
"""With --resume and watermark pointing at a user message, skip gap."""
# Simulates a deleted message shifting DB positions so the watermark
# lands on a user turn instead of the expected assistant turn.
session = _make_session(
[
ChatMessage(role="user", content="turn 1"),
ChatMessage(role="assistant", content="reply 1"),
ChatMessage(
role="user", content="turn 2"
), # ← watermark points here (role=user)
ChatMessage(role="assistant", content="reply 2"),
ChatMessage(role="user", content="turn 3"),
]
)
result, was_compacted = await _build_query_message(
"turn 3",
session,
use_resume=True,
transcript_msg_count=3, # prior[2].role == "user" — misaligned
session_id="test-session",
)
# Misaligned watermark → skip gap, return bare message
assert result == "turn 3"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
"""With --resume and stale transcript, gap context is prepended."""
@@ -204,7 +233,7 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
)
# Mock _compress_messages to return the messages as-is
async def _mock_compress(msgs):
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
@@ -237,7 +266,7 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
]
)
async def _mock_compress(msgs):
async def _mock_compress(msgs, target_tokens=None):
return msgs, True # Simulate actual compaction
monkeypatch.setattr(
@@ -253,3 +282,85 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
session_id="test-session",
)
assert was_compacted is True
@pytest.mark.asyncio
async def test_build_query_no_resume_at_token_floor():
"""When target_tokens is at or below the floor, return bare message.
This is the final escape hatch: if the retry budget is exhausted and
even the most aggressive compression might not fit, skip history
injection entirely so the user always gets a response.
"""
session = _make_session(
[
ChatMessage(role="user", content="old question"),
ChatMessage(role="assistant", content="old answer"),
ChatMessage(role="user", content="new question"),
]
)
result, was_compacted = await _build_query_message(
"new question",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR,
)
# At the floor threshold, no history is injected
assert result == "new question"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_below_token_floor():
"""target_tokens strictly below floor also returns bare message."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR - 1,
)
assert result == "new"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_above_token_floor_compresses(monkeypatch):
"""target_tokens just above the floor still triggers compression."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
async def _mock_compress(msgs, target_tokens=None):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR + 1,
)
# Above the floor → history is injected (not the bare message)
assert "<conversation_history>" in result
assert "Now, the user says:\nnew" in result

View File

@@ -7,6 +7,7 @@ tests will catch it immediately.
"""
import inspect
from typing import cast
import pytest
@@ -90,6 +91,39 @@ def test_agent_options_accepts_required_fields():
assert opts.cwd == "/tmp"
def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
"""Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.
The production code always includes ``exclude_dynamic_sections=True`` in the preset
dict. This compat test mirrors that exact shape so any SDK version that starts
rejecting unknown keys will be caught here rather than at runtime.
"""
from claude_agent_sdk import ClaudeAgentOptions
from claude_agent_sdk.types import SystemPromptPreset
from .service import _build_system_prompt_value
# Call the production helper directly so this test is tied to the real
# dict shape rather than a hand-rolled copy.
preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
assert isinstance(
preset, dict
), "_build_system_prompt_value must return a dict when caching is on"
sdk_preset = cast(SystemPromptPreset, preset)
opts = ClaudeAgentOptions(system_prompt=sdk_preset)
assert opts.system_prompt == sdk_preset
def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
"""When cross_user_cache=False (e.g. on --resume turns), the helper must return
a plain string so the preset+resume crash is avoided."""
from .service import _build_system_prompt_value
result = _build_system_prompt_value("my prompt", cross_user_cache=False)
assert result == "my prompt", "Must return the raw string, not a preset dict"
def test_agent_options_accepts_all_our_fields():
"""Comprehensive check of every field we use in service.py."""
from claude_agent_sdk import ClaudeAgentOptions

View File

@@ -29,6 +29,7 @@ from claude_agent_sdk import (
ToolResultBlock,
ToolUseBlock,
)
from claude_agent_sdk.types import SystemPromptPreset
from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from opentelemetry import trace as otel_trace
@@ -260,6 +261,11 @@ class ReducedContext(NamedTuple):
resume_file: str | None
transcript_lost: bool
tried_compaction: bool
# Token budget for history compression on the DB-message fallback path.
# None means "use model-aware default". Lowered on each retry so
# compress_context applies progressively more aggressive reduction
# (LLM summarize → content truncate → middle-out delete → first/last trim).
target_tokens: int | None = None
@dataclass
@@ -304,6 +310,10 @@ class _RetryState:
adapter: SDKResponseAdapter
transcript_builder: TranscriptBuilder
usage: _TokenUsage
# Token budget for history compression on retries (DB-message fallback path).
# None = model-aware default. Lowered each retry for progressively more
# aggressive compression (LLM summarize → truncate → middle-out → trim).
target_tokens: int | None = None
@dataclass
@@ -335,12 +345,34 @@ class _StreamContext:
lock: AsyncClusterLock
# Per-retry token budgets for the no-transcript (use_resume=False) path.
# When there is no CLI native session to --resume, context is built from DB
# messages via _format_conversation_context. For large sessions this text
# can exceed the model context window; each retry lowers the token budget so
# compress_context applies progressively more aggressive reduction:
# LLM summarize → content truncate → middle-out delete → first/last trim.
# Index 0 = first retry, 1 = second retry; last value applies beyond that.
_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)
# Below this token budget the model context is so tight that injecting any
# conversation history would likely exceed the limit regardless of content.
# _build_query_message returns the bare message when target_tokens falls to
# or below this floor, giving the user a response instead of a hard error.
_BARE_MESSAGE_TOKEN_FLOOR: int = 5_000
# Tight token budget for seeding the transcript builder on turns where no
# CLI native session exists. Kept below _RETRY_TARGET_TOKENS[0] so the
# seeded JSONL upload stays compact and future gap injections are small.
_SEED_TARGET_TOKENS: int = 30_000
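Concretely, the ladder resolves as follows; `_budget_for_attempt` is an illustrative wrapper around the clamped-index expression that `_reduce_context` inlines below:

```python
_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)  # as defined above

def _budget_for_attempt(attempt: int) -> int:
    # Same clamped indexing _reduce_context uses; attempt is 1-based.
    idx = max(0, attempt - 1)
    return _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]

assert _budget_for_attempt(1) == 50_000  # first retry
assert _budget_for_attempt(2) == 15_000  # second retry
assert _budget_for_attempt(5) == 15_000  # last value applies beyond the tuple
```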
async def _reduce_context(
transcript_content: str,
tried_compaction: bool,
session_id: str,
sdk_cwd: str,
log_prefix: str,
attempt: int = 1,
) -> ReducedContext:
"""Prepare reduced context for a retry attempt.
@@ -348,9 +380,19 @@ async def _reduce_context(
On subsequent retries (or if compaction fails), drops the transcript
entirely so the query is rebuilt from DB messages only.
`transcript_lost` is True when the transcript was dropped (caller
should set `skip_transcript_upload`).
When no transcript is available (use_resume=False fallback path), returns
a decreasing ``target_tokens`` budget so ``compress_context`` applies
progressively more aggressive reduction (LLM summarize → content truncate
→ middle-out delete → first/last trim). The budget applies in
``_build_query_message`` and is lowered on each retry.
``transcript_lost`` is True when the transcript was dropped (caller
should set ``skip_transcript_upload``).
"""
# Token budget for the DB fallback on this attempt (no-transcript path).
idx = max(0, attempt - 1)
retry_target = _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]
# First retry: try compacting our transcript builder state.
# Note: the CLI native --resume file is not updated with the compacted
# content (it would require emitting CLI-native JSONL format), so the
@@ -374,9 +416,14 @@ async def _reduce_context(
return ReducedContext(tb, False, None, False, True)
logger.warning("%s Compaction failed, dropping transcript", log_prefix)
# Subsequent retry or compaction failed: drop transcript entirely
logger.warning("%s Dropping transcript, rebuilding from DB messages", log_prefix)
return ReducedContext(TranscriptBuilder(), False, None, True, True)
# Subsequent retry or compaction failed: drop transcript entirely.
# Return retry_target so the caller compresses DB messages to that budget.
logger.warning(
"%s Dropping transcript, rebuilding from DB messages" " (target_tokens=%d)",
log_prefix,
retry_target,
)
return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)
def _append_error_marker(
@@ -705,6 +752,34 @@ def _is_fallback_stderr(line: str) -> bool:
return "fallback model" in line.lower()
def _build_system_prompt_value(
system_prompt: str,
cross_user_cache: bool,
) -> str | SystemPromptPreset:
"""Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`.
When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset`
dict so the Claude Code default prompt becomes a cacheable prefix shared
across all users; our custom *system_prompt* is appended after it.
When disabled (or if the SDK is too old to support ``SystemPromptPreset``),
the raw *system_prompt* string is returned unchanged.
An empty *system_prompt* is accepted: the preset dict will have
``append: ""`` which the SDK treats as no custom suffix.
"""
if cross_user_cache:
logger.debug("Using SystemPromptPreset for cross-user prompt cache")
return SystemPromptPreset(
type="preset",
preset="claude_code",
append=system_prompt,
exclude_dynamic_sections=True,
)
logger.debug("Cross-user prompt cache disabled, using raw string")
return system_prompt
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
@@ -801,6 +876,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
async def _compress_messages(
messages: list[ChatMessage],
target_tokens: int | None = None,
) -> tuple[list[ChatMessage], bool]:
"""Compress a list of messages if they exceed the token threshold.
@@ -809,6 +885,10 @@ async def _compress_messages(
`_compress_messages` and `compact_transcript` share this helper so
client acquisition and error handling are consistent.
``target_tokens`` sets a hard ceiling for the compressed output so
callers can enforce a tighter budget on retries. When ``None``,
``compress_context`` uses the model-aware default.
See also:
`_run_compression` — shared compression with timeout guards.
`compact_transcript` — compresses JSONL transcript entries.
@@ -832,7 +912,9 @@ async def _compress_messages(
messages_dict.append(msg_dict)
try:
result = await _run_compression(messages_dict, config.model, "[SDK]")
result = await _run_compression(
messages_dict, config.model, "[SDK]", target_tokens=target_tokens
)
except Exception as exc:
# Guard against timeouts or unexpected errors in compression —
# return the original messages so the caller can proceed without
@@ -961,44 +1043,139 @@ async def _build_query_message(
use_resume: bool,
transcript_msg_count: int,
session_id: str,
target_tokens: int | None = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
When ``use_resume=True``, the CLI has the full session via ``--resume``;
only a gap-fill prefix is injected when the transcript is stale.
When ``use_resume=False``, the CLI starts a fresh session with no prior
context, so the full prior session is always compressed and injected via
``_format_conversation_context``. ``compress_context`` handles size
reduction internally (LLM summarize → content truncate → middle-out delete
→ first/last trim). ``target_tokens`` decreases on each retry to force
progressively more aggressive compression when the first attempt exceeds
context limits.
Returns:
Tuple of (query_message, was_compacted).
"""
msg_count = len(session.messages)
prior = session.messages[:-1] # all turns except the current user message
logger.info(
"[SDK] [%s] Context path: use_resume=%s, transcript_msg_count=%d,"
" db_msg_count=%d, target_tokens=%s",
session_id[:8],
use_resume,
transcript_msg_count,
msg_count,
target_tokens,
)
if use_resume and transcript_msg_count > 0:
if transcript_msg_count < msg_count - 1:
gap = session.messages[transcript_msg_count:-1]
compressed, was_compressed = await _compress_messages(gap)
# Sanity-check the watermark: the last covered position should be
# an assistant turn. A user-role message here means the count is
# misaligned (e.g. a message was deleted and DB positions shifted).
# Skip the gap rather than injecting wrong context — the CLI session
# loaded via --resume still has good history.
if prior[transcript_msg_count - 1].role != "assistant":
logger.warning(
"[SDK] [%s] Watermark misaligned: prior[%d].role=%r"
" (expected 'assistant') — skipping gap to avoid"
" injecting wrong context (transcript=%d, db=%d)",
session_id[:8],
transcript_msg_count - 1,
prior[transcript_msg_count - 1].role,
transcript_msg_count,
msg_count,
)
return current_message, False
gap = prior[transcript_msg_count:]
compressed, was_compressed = await _compress_messages(gap, target_tokens)
gap_context = _format_conversation_context(compressed)
if gap_context:
logger.info(
"[SDK] Transcript stale: covers %d of %d messages, "
"gap=%d (compressed=%s)",
"gap=%d (compressed=%s), gap_context_bytes=%d",
transcript_msg_count,
msg_count,
len(gap),
was_compressed,
len(gap_context),
)
return (
f"{gap_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
logger.warning(
"[SDK] [%s] Transcript stale: gap produced empty context"
" (%d msgs, transcript=%d/%d) — sending message without gap prefix",
session_id[:8],
len(gap),
transcript_msg_count,
msg_count,
)
else:
logger.info(
"[SDK] [%s] --resume covers full context (%d messages)",
session_id[:8],
transcript_msg_count,
)
return current_message, False
elif not use_resume and msg_count > 1:
# No --resume: the CLI starts a fresh session with no prior context.
# Injecting only the post-transcript gap would omit the transcript-covered
# prefix entirely, so always compress the full prior session here.
# compress_context handles size reduction internally (LLM summarize →
# content truncate → middle-out delete → first/last trim).
# Final escape hatch: if the token budget is at or below the floor,
# the model context is so tight that even fully compressed history
# would risk a "prompt too long" error. Return the bare message so
# the user always gets a response rather than a hard failure.
if target_tokens is not None and target_tokens <= _BARE_MESSAGE_TOKEN_FLOOR:
logger.warning(
"[SDK] [%s] target_tokens=%d at or below floor (%d) —"
" skipping history injection to guarantee response delivery"
" (session has %d messages)",
session_id[:8],
target_tokens,
_BARE_MESSAGE_TOKEN_FLOOR,
msg_count,
)
return current_message, False
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({msg_count} messages) — no transcript for --resume"
"[SDK] [%s] No --resume for %d-message session — compressing"
" full session history (pod affinity issue or first turn after"
" restore failure); target_tokens=%s",
session_id[:8],
msg_count,
target_tokens,
)
compressed, was_compressed = await _compress_messages(session.messages[:-1])
compressed, was_compressed = await _compress_messages(prior, target_tokens)
history_context = _format_conversation_context(compressed)
if history_context:
logger.info(
"[SDK] [%s] Fallback context built: compressed=%s," " context_bytes=%d",
session_id[:8],
was_compressed,
len(history_context),
)
return (
f"{history_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
logger.warning(
"[SDK] [%s] Fallback context empty after compression"
" (%d messages) — sending message without history",
session_id[:8],
len(prior),
)
return current_message, False
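To make the slicing concrete, a worked example assuming a five-message session whose transcript watermark covers the first two turns (the same shape the stale-transcript tests above exercise):

```python
# Hypothetical five-message session; the last entry is the current user turn.
messages = ["q1", "a1", "q2", "a2", "q3"]
transcript_msg_count = 2              # transcript already covers q1/a1

prior = messages[:-1]                 # ["q1", "a1", "q2", "a2"]
gap = prior[transcript_msg_count:]    # ["q2", "a2"] — only uncovered turns

# Watermark sanity check: the last covered position must be an assistant turn.
assert prior[transcript_msg_count - 1] == "a1"
assert gap == ["q2", "a2"]
```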
@@ -1927,6 +2104,48 @@ async def _run_stream_attempt(
)
async def _seed_transcript(
session: ChatSession,
transcript_builder: TranscriptBuilder,
transcript_covers_prefix: bool,
transcript_msg_count: int,
log_prefix: str,
) -> tuple[str, bool, int]:
"""Seed the transcript builder from compressed DB messages.
Called when ``use_resume=False`` and no prior transcript exists in storage
so that ``upload_transcript`` saves a compact version for future turns.
This ensures the next turn can use the full-session compression path with
the benefit of an already-compressed baseline, and a restored CLI session
on the next pod gets a usable compact base even for sessions that started
on old pods.
Returns ``(transcript_content, transcript_covers_prefix, transcript_msg_count)``
updated values — unchanged if seeding is not possible.
"""
if len(session.messages) <= 1:
return "", transcript_covers_prefix, transcript_msg_count
_prior = session.messages[:-1]
_comp, _ = await _compress_messages(_prior, _SEED_TARGET_TOKENS)
if not _comp:
return "", transcript_covers_prefix, transcript_msg_count
_seeded = _session_messages_to_transcript(_comp)
if not _seeded or not validate_transcript(_seeded):
return "", transcript_covers_prefix, transcript_msg_count
transcript_builder.load_previous(_seeded, log_prefix=log_prefix)
logger.info(
"%s Seeded transcript from %d compressed DB messages"
" for next-turn upload (seed_target_tokens=%d)",
log_prefix,
len(_comp),
_SEED_TARGET_TOKENS,
)
return _seeded, True, len(_prior)
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -2198,9 +2417,20 @@ async def stream_chat_completion_sdk(
# Builder loaded but CLI native session not available.
# --resume will not be used this turn; upload after turn
# will seed the native session for the next turn.
#
# Still record transcript_msg_count so _build_query_message
# can use the transcript-aware gap path (inject only new
# messages since the transcript end) instead of compressing
# the full DB history. This avoids prompt-too-long on
# large sessions where the CLI session is temporarily
# unavailable (e.g. mixed-version rolling deployment).
transcript_msg_count = dl.message_count
logger.info(
"%s CLI session not restored — running without --resume this turn",
"%s CLI session not restored — running without"
" --resume this turn (transcript_msg_count=%d for"
" gap-aware fallback)",
log_prefix,
transcript_msg_count,
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
@@ -2295,8 +2525,19 @@ async def stream_chat_completion_sdk(
sid,
)
# Use SystemPromptPreset for cross-user prompt caching.
# WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when
# excludeDynamicSections=True is in the initialize request AND
# --resume is active. Disable the preset on resumed turns.
# Turn 1 still gets the preset (no --resume).
_cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume
system_prompt_value = _build_system_prompt_value(
system_prompt,
cross_user_cache=_cross_user,
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"system_prompt": system_prompt_value,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": allowed,
"disallowed_tools": disallowed,
@@ -2425,6 +2666,22 @@ async def stream_chat_completion_sdk(
if attachments.hint:
query_message = f"{query_message}\n\n{attachments.hint}"
# When running without --resume and no prior transcript in storage,
# seed the transcript builder from compressed DB messages so that
# upload_transcript saves a compact version for future turns.
if not use_resume and not transcript_content and not skip_transcript_upload:
(
transcript_content,
transcript_covers_prefix,
transcript_msg_count,
) = await _seed_transcript(
session,
transcript_builder,
transcript_covers_prefix,
transcript_msg_count,
log_prefix,
)
tried_compaction = False
# Build the per-request context carrier (shared across attempts).
@@ -2507,12 +2764,14 @@ async def stream_chat_completion_sdk(
session_id,
sdk_cwd,
log_prefix,
attempt=attempt,
)
state.transcript_builder = ctx.builder
state.use_resume = ctx.use_resume
state.resume_file = ctx.resume_file
tried_compaction = ctx.tried_compaction
state.transcript_msg_count = 0
state.target_tokens = ctx.target_tokens
if ctx.transcript_lost:
skip_transcript_upload = True
@@ -2530,9 +2789,18 @@ async def stream_chat_completion_sdk(
# T2+ retry without --resume: do not pass --session-id.
# The T1 session file already exists at that path; re-using
# the same ID would fail with "Session ID already in use".
sdk_options_kwargs_retry.pop("resume", None)
sdk_options_kwargs_retry.pop("session_id", None)
# Recompute system_prompt for retry — ctx.use_resume may have
# changed (context reduction enabled --resume). CLI 2.1.97
# crashes when excludeDynamicSections=True is combined with
# --resume, so disable the cross-user preset on resumed turns.
_cross_user_retry = (
config.claude_agent_cross_user_prompt_cache and not ctx.use_resume
)
sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value(
system_prompt, cross_user_cache=_cross_user_retry
)
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
state.query_message, state.was_compacted = await _build_query_message(
current_message,
@@ -2540,6 +2808,7 @@ async def stream_chat_completion_sdk(
state.use_resume,
state.transcript_msg_count,
session_id,
target_tokens=state.target_tokens,
)
if attachments.hint:
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
@@ -3025,6 +3294,21 @@ async def stream_chat_completion_sdk(
# the shielded inner coroutine continues running to completion so the
# upload is not lost. This is intentional and matches the pattern
# used for upload_transcript immediately above.
#
# NOTE: upload is attempted regardless of state.use_resume — even when
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# upload_cli_session silently skips when the file is absent, so this is
# always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
# when our custom JSONL transcript is dropped (transcript_lost=True on
# reduced-context retries) but the CLI's native session file is written
# independently. Blocking CLI upload on transcript_lost would prevent
# T1 prompt-too-long retries from uploading their valid session file,
# breaking --resume on the next pod. The ended_with_stream_error gate
# above already covers actual turn failures.
if (
config.claude_agent_use_resume
and user_id
@@ -3032,9 +3316,15 @@ async def stream_chat_completion_sdk(
and session is not None
and state is not None
and not ended_with_stream_error
):
logger.info(
"%s Attempting CLI session upload"
" (use_resume=%s, has_history=%s, skip_transcript=%s)",
log_prefix,
state.use_resume,
has_history,
skip_transcript_upload,
)
try:
await asyncio.shield(
upload_cli_session(

View File

@@ -15,6 +15,7 @@ from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
from .conftest import build_test_transcript as _build_transcript
from .service import (
_RETRY_TARGET_TOKENS,
ReducedContext,
_is_prompt_too_long,
_is_tool_only_message,
@@ -208,6 +209,24 @@ class TestReduceContext:
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_drop_returns_target_tokens_attempt_1(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=1)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[0]
@pytest.mark.asyncio
async def test_drop_returns_target_tokens_attempt_2(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=2)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[1]
@pytest.mark.asyncio
async def test_drop_clamps_attempt_beyond_limits(self) -> None:
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=99)
assert ctx.transcript_lost is True
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[-1]
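These three tests pin down the attempt-to-budget mapping: attempts index into `_RETRY_TARGET_TOKENS` and clamp at the last entry. A minimal sketch of that clamping with placeholder budgets (the real values live in service.py):

```python
RETRY_TARGET_TOKENS = (100_000, 50_000)  # placeholder budgets for illustration

def target_tokens_for(attempt: int) -> int:
    # attempt is 1-based; clamp so attempt=99 maps to the last, tightest budget
    index = min(max(attempt - 1, 0), len(RETRY_TARGET_TOKENS) - 1)
    return RETRY_TARGET_TOKENS[index]

assert target_tokens_for(1) == RETRY_TARGET_TOKENS[0]
assert target_tokens_for(2) == RETRY_TARGET_TOKENS[1]
assert target_tokens_for(99) == RETRY_TARGET_TOKENS[-1]
```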
# ---------------------------------------------------------------------------
# _iter_sdk_messages

View File

@@ -8,7 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot import config as cfg_mod
from .service import (
_build_system_prompt_value,
_is_sdk_disconnect_error,
_normalize_model_name,
_prepare_file_attachments,
@@ -397,6 +400,7 @@ _CONFIG_ENV_VARS = (
"OPENAI_BASE_URL",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
)
@@ -656,3 +660,62 @@ class TestSafeCloseSdkClient:
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
with pytest.raises(ValueError, match="invalid argument"):
await _safe_close_sdk_client(client, "[test]")
# ---------------------------------------------------------------------------
# SystemPromptPreset — cross-user prompt caching
# ---------------------------------------------------------------------------
class TestSystemPromptPreset:
"""Tests for _build_system_prompt_value — cross-user prompt caching."""
def test_preset_dict_structure_when_enabled(self):
"""When cross_user_cache is True, returns a _SystemPromptPreset dict."""
custom_prompt = "You are a helpful assistant."
result = _build_system_prompt_value(custom_prompt, cross_user_cache=True)
assert isinstance(result, dict)
assert result["type"] == "preset"
assert result["preset"] == "claude_code"
assert result["append"] == custom_prompt
assert result["exclude_dynamic_sections"] is True
def test_raw_string_when_disabled(self):
"""When cross_user_cache is False, returns the raw string."""
custom_prompt = "You are a helpful assistant."
result = _build_system_prompt_value(custom_prompt, cross_user_cache=False)
assert isinstance(result, str)
assert result == custom_prompt
def test_empty_string_with_cache_enabled(self):
"""Empty system_prompt with cross_user_cache=True produces append=''."""
result = _build_system_prompt_value("", cross_user_cache=True)
assert isinstance(result, dict)
assert result["type"] == "preset"
assert result["preset"] == "claude_code"
assert result["append"] == ""
assert result["exclude_dynamic_sections"] is True
def test_default_config_is_enabled(self, _clean_config_env):
"""The default value for claude_agent_cross_user_prompt_cache is True."""
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.claude_agent_cross_user_prompt_cache is True
def test_env_var_disables_cache(self, _clean_config_env, monkeypatch):
"""CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false disables caching."""
monkeypatch.setenv("CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE", "false")
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.claude_agent_cross_user_prompt_cache is False
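Taken together, these tests pin down the helper's full contract. A sketch consistent with them, assuming a plain dict return value (the real helper may build the SDK's typed `_SystemPromptPreset` instead):

```python
from typing import Any

def _build_system_prompt_value(
    system_prompt: str, *, cross_user_cache: bool
) -> str | dict[str, Any]:
    """Sketch consistent with TestSystemPromptPreset; not the real implementation."""
    if not cross_user_cache:
        return system_prompt
    return {
        "type": "preset",
        "preset": "claude_code",
        "append": system_prompt,
        # Maps to excludeDynamicSections in the CLI initialize request.
        "exclude_dynamic_sections": True,
    }
```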

View File

@@ -960,7 +960,7 @@ class TestRunCompression:
)
call_count = [0]
async def _compress_side_effect(*, messages, model, client, target_tokens=None):
call_count[0] += 1
if client is not None:
# Simulate a hang that exceeds the timeout

View File

@@ -1179,6 +1179,7 @@ async def _run_compression(
messages: list[dict],
model: str,
log_prefix: str,
target_tokens: int | None = None,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback.
@@ -1187,6 +1188,12 @@ async def _run_compression(
truncation-based compression which drops older messages without
summarization.
``target_tokens`` sets a hard token ceiling for the compressed output.
When ``None``, ``compress_context`` derives the limit from the model's
context window. Pass a smaller value on retries to force more aggressive
compression — the compressor will LLM-summarize, content-truncate,
middle-out delete, and first/last trim until the result fits.
A 60-second timeout prevents a hung LLM call from blocking the
retry path indefinitely. The truncation fallback also has a
30-second timeout to guard against slow tokenization on very large
@@ -1196,18 +1203,27 @@ async def _run_compression(
if client is None:
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
return await asyncio.wait_for(
compress_context(
messages=messages, model=model, client=None, target_tokens=target_tokens
),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
try:
return await asyncio.wait_for(
compress_context(
messages=messages,
model=model,
client=client,
target_tokens=target_tokens,
),
timeout=_COMPACTION_TIMEOUT_SECONDS,
)
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await asyncio.wait_for(
compress_context(
messages=messages, model=model, client=None, target_tokens=target_tokens
),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
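A hedged usage sketch of the retry path: the hunk above elides the parameters before `messages`, so passing `client` by keyword is an assumption, and the budget value is illustrative:

```python
# Tighter target_tokens on each retry forces the compressed history to shrink.
result = await _run_compression(
    client=client,        # assumed keyword; the full signature is elided above
    messages=session.messages[:-1],
    model=model,
    log_prefix="[SDK] [sess-abc1]",
    target_tokens=50_000,  # None derives the ceiling from the model's window
)
```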

View File

@@ -349,7 +349,7 @@ class UserCreditBase(ABC):
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.billing import (
clear_insufficient_funds_notifications,
)
@@ -554,7 +554,7 @@ class UserCreditBase(ABC):
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid a circular dependency with the executor package
from backend.executor.billing import (
clear_insufficient_funds_notifications,
)

View File

@@ -852,6 +852,7 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
cache_read_token_count: int = 0
cache_creation_token_count: int = 0
cost: int = 0
extra_cost: int = 0
extra_steps: int = 0
provider_cost: float | None = None

View File

@@ -8,6 +8,7 @@ from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
@@ -142,6 +143,7 @@ class UserCostSummary(BaseModel):
total_cache_read_tokens: int = 0
total_cache_creation_tokens: int = 0
request_count: int
cost_bearing_request_count: int = 0
class CostLogRow(BaseModel):
@@ -163,12 +165,27 @@ class CostLogRow(BaseModel):
cache_creation_tokens: int | None = None
class CostBucket(BaseModel):
bucket: str
count: int
class PlatformCostDashboard(BaseModel):
by_provider: list[ProviderCostSummary]
by_user: list[UserCostSummary]
total_cost_microdollars: int
total_requests: int
total_users: int
total_input_tokens: int = 0
total_output_tokens: int = 0
avg_input_tokens_per_request: float = 0.0
avg_output_tokens_per_request: float = 0.0
avg_cost_microdollars_per_request: float = 0.0
cost_p50_microdollars: float = 0.0
cost_p75_microdollars: float = 0.0
cost_p95_microdollars: float = 0.0
cost_p99_microdollars: float = 0.0
cost_buckets: list[CostBucket] = []
def _si(row: dict, field: str) -> int:
@@ -228,6 +245,66 @@ def _build_prisma_where(
return where
def _build_raw_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list]:
"""Build a parameterised WHERE clause for raw SQL queries.
Mirrors the filter logic of ``_build_prisma_where`` so there is a single
source of truth for which columns are filtered and how. The first clause
always restricts to ``cost_usd`` tracking type unless *tracking_type* is
explicitly provided by the caller.
"""
params: list = []
clauses: list[str] = []
idx = 1
# Always filter by tracking type — defaults to cost_usd for percentile /
# bucket queries that only make sense on cost-denominated rows.
tt = tracking_type if tracking_type is not None else "cost_usd"
clauses.append(f'"trackingType" = ${idx}')
params.append(tt)
idx += 1
if start is not None:
clauses.append(f'"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end is not None:
clauses.append(f'"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider is not None:
clauses.append(f'"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id is not None:
clauses.append(f'"userId" = ${idx}')
params.append(user_id)
idx += 1
if model is not None:
clauses.append(f'"model" = ${idx}')
params.append(model)
idx += 1
if block_name is not None:
clauses.append(f'LOWER("blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
return (" AND ".join(clauses), params)
@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
start: datetime | None = None,
@@ -256,6 +333,14 @@ async def get_platform_cost_dashboard(
start, end, provider, user_id, model, block_name, tracking_type
)
# For per-user tracking-type breakdown we intentionally omit the
# tracking_type filter so cost_usd and tokens rows are always present.
# This ensures cost_bearing_request_count is correct even when the caller
# is filtering the main view by a different tracking_type.
where_no_tracking_type = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type=None
)
sum_fields = {
"costMicrodollars": True,
"inputTokens": True,
@@ -266,13 +351,18 @@ async def get_platform_cost_dashboard(
"trackingAmount": True,
}
# Build parameterised WHERE clause for the raw SQL percentile/bucket
# queries. Uses _build_raw_where so filter logic is shared with
# _build_prisma_where and only maintained in one place.
# Always force tracking_type=None here so _build_raw_where defaults to
# "cost_usd" — percentile and histogram queries only make sense on
# cost-denominated rows, regardless of what the caller is filtering.
raw_where, raw_params = _build_raw_where(
start, end, provider, user_id, model, block_name, tracking_type=None
)
# Queries that always run regardless of tracking_type filter.
common_queries = [
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
PrismaLog.prisma().group_by(
@@ -288,20 +378,125 @@ async def get_platform_cost_dashboard(
sum=sum_fields,
count=True,
),
# Per-user cost-bearing request count: group by (userId, trackingType)
# so we can compute the correct denominator for per-user avg cost.
# Uses where_no_tracking_type so cost_usd rows are always included
# even when the caller filters the main view by a different tracking_type.
PrismaLog.prisma().group_by(
by=["userId", "trackingType"],
where=where_no_tracking_type,
count=True,
),
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
count=True,
),
# Total aggregate (filtered): group by (provider, trackingType) so we can
# compute cost-bearing and token-bearing denominators for avg stats.
PrismaLog.prisma().group_by(
by=["provider"],
by=["provider", "trackingType"],
where=where,
sum={"costMicrodollars": True},
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
),
# Percentile distribution of cost per request (respects all filters).
query_raw_with_schema(
"SELECT"
" percentile_cont(0.5) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p50,'
" percentile_cont(0.75) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p75,'
" percentile_cont(0.95) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p95,'
" percentile_cont(0.99) WITHIN GROUP"
' (ORDER BY "costMicrodollars") as p99'
' FROM {schema_prefix}"PlatformCostLog"'
f" WHERE {raw_where}",
*raw_params,
),
# Histogram buckets for cost distribution (respects all filters).
# NULL costMicrodollars is excluded explicitly to prevent such rows
# from falling through all WHEN clauses into the ELSE '$10+' bucket.
query_raw_with_schema(
"SELECT"
" CASE"
' WHEN "costMicrodollars" < 500000'
" THEN '$0-0.50'"
' WHEN "costMicrodollars" < 1000000'
" THEN '$0.50-1'"
' WHEN "costMicrodollars" < 2000000'
" THEN '$1-2'"
' WHEN "costMicrodollars" < 5000000'
" THEN '$2-5'"
' WHEN "costMicrodollars" < 10000000'
" THEN '$5-10'"
" ELSE '$10+'"
" END as bucket,"
" COUNT(*) as count"
' FROM {schema_prefix}"PlatformCostLog"'
f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
" GROUP BY bucket"
' ORDER BY MIN("costMicrodollars")',
*raw_params,
),
]
# Only run the unfiltered aggregate query when tracking_type is set;
# when tracking_type is None, the filtered query already contains all
# tracking types and reusing it avoids a redundant full aggregation.
if tracking_type is not None:
common_queries.append(
# Total aggregate (no tracking_type filter): used to compute
# cost_bearing_requests and token_bearing_requests denominators so
# global avg stats remain meaningful when the caller filters the
# main view by a specific tracking_type (e.g. 'tokens').
PrismaLog.prisma().group_by(
by=["provider", "trackingType"],
where=where_no_tracking_type,
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
)
)
results = await asyncio.gather(*common_queries)
# Unpack results by name for clarity.
by_provider_groups = results[0]
by_user_groups = results[1]
by_user_tracking_groups = results[2]
total_user_groups = results[3]
total_agg_groups = results[4]
percentile_rows = results[5]
bucket_rows = results[6]
# When tracking_type is None, the filtered and unfiltered queries are
# identical — reuse total_agg_groups to avoid the extra DB round-trip.
total_agg_no_tracking_type_groups = (
results[7] if tracking_type is not None else total_agg_groups
)
# Compute token grand-totals from the unfiltered aggregate so they remain
# consistent with the avg-token stats (which also use unfiltered data).
# Using by_provider_groups here would give 0 tokens when tracking_type='cost_usd'
# because cost_usd rows carry no token data, contradicting non-zero averages.
total_input_tokens = sum(
_si(r, "inputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
total_output_tokens = sum(
_si(r, "outputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
@@ -328,6 +523,61 @@ async def get_platform_cost_dashboard(
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
total_requests = sum(_ca(r) for r in total_agg_groups)
# Extract percentile values from the raw query result.
pctl = percentile_rows[0] if percentile_rows else {}
cost_p50 = float(pctl.get("p50") or 0)
cost_p75 = float(pctl.get("p75") or 0)
cost_p95 = float(pctl.get("p95") or 0)
cost_p99 = float(pctl.get("p99") or 0)
# Build cost bucket list.
cost_buckets: list[CostBucket] = [
CostBucket(bucket=r["bucket"], count=int(r["count"])) for r in bucket_rows
]
# Avg-stat numerators and denominators are derived from the unfiltered
# aggregate so they remain meaningful when the caller filters by a specific
# tracking_type. Example: filtering by 'tokens' excludes cost_usd rows from
# total_agg_groups, so avg_cost would always be 0 if we used that; using
# total_agg_no_tracking_type_groups gives the correct cost_usd total/count.
avg_cost_total = sum(
_si(r, "costMicrodollars")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "cost_usd"
)
cost_bearing_requests = sum(
_ca(r)
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "cost_usd"
)
avg_input_total = sum(
_si(r, "inputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
avg_output_total = sum(
_si(r, "outputTokens")
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Token-bearing request count: only rows where trackingType == "tokens".
# Token averages must use this denominator; cost_usd rows do not carry tokens.
token_bearing_requests = sum(
_ca(r)
for r in total_agg_no_tracking_type_groups
if r.get("trackingType") == "tokens"
)
# Per-user cost-bearing request count: used for per-user avg cost so the
# denominator matches the numerator (cost_usd rows only, per user).
user_cost_bearing_counts: dict[str, int] = {}
for r in by_user_tracking_groups:
if r.get("trackingType") == "cost_usd" and r.get("userId"):
uid = r["userId"]
user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
r
)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
@@ -355,12 +605,35 @@ async def get_platform_cost_dashboard(
total_cache_read_tokens=_si(r, "cacheReadTokens"),
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
request_count=_ca(r),
cost_bearing_request_count=user_cost_bearing_counts.get(
r.get("userId") or "", 0
),
)
for r in by_user_groups
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
total_users=total_users,
total_input_tokens=total_input_tokens,
total_output_tokens=total_output_tokens,
avg_input_tokens_per_request=(
avg_input_total / token_bearing_requests
if token_bearing_requests > 0
else 0.0
),
avg_output_tokens_per_request=(
avg_output_total / token_bearing_requests
if token_bearing_requests > 0
else 0.0
),
avg_cost_microdollars_per_request=(
avg_cost_total / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
),
cost_p50_microdollars=cost_p50,
cost_p75_microdollars=cost_p75,
cost_p95_microdollars=cost_p95,
cost_p99_microdollars=cost_p99,
cost_buckets=cost_buckets,
)
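A concrete example of why the denominators are split, with numbers mirroring `test_global_avg_cost_nonzero_when_filtering_by_tokens` below:

```python
# Unfiltered aggregate rows:
#   tokens row:   1000 input / 500 output tokens over 5 requests, zero cost
#   cost_usd row: 10_000 microdollars over 4 requests, no token data
avg_input_tokens_per_request = 1000 / 5         # 200.0, token-bearing rows only
avg_cost_microdollars_per_request = 10_000 / 4  # 2500.0, cost-bearing rows only
# 1_000_000 microdollars == $1, so 2500.0 is $0.0025 per request.
# A single combined denominator (9 requests) would understate both averages.
```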

View File

@@ -10,6 +10,8 @@ from backend.util.json import SafeJson
from .platform_cost import (
PlatformCostEntry,
_build_prisma_where,
_build_raw_where,
_build_where,
_mask_email,
get_platform_cost_dashboard,
@@ -156,6 +158,84 @@ class TestBuildWhere:
assert 'p."trackingType" = $3' in sql
class TestBuildPrismaWhere:
def test_both_start_and_end(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
where = _build_prisma_where(start, end, None, None)
assert where["createdAt"] == {"gte": start, "lte": end}
def test_end_only(self):
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
where = _build_prisma_where(None, end, None, None)
assert where["createdAt"] == {"lte": end}
def test_start_only(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
where = _build_prisma_where(start, None, None, None)
assert where["createdAt"] == {"gte": start}
def test_no_filters(self):
where = _build_prisma_where(None, None, None, None)
assert "createdAt" not in where
def test_provider_lowercased(self):
where = _build_prisma_where(None, None, "OpenAI", None)
assert where["provider"] == "openai"
def test_model_filter(self):
where = _build_prisma_where(None, None, None, None, model="gpt-4")
assert where["model"] == "gpt-4"
def test_block_name_case_insensitive(self):
where = _build_prisma_where(None, None, None, None, block_name="LLMBlock")
assert where["blockName"] == {"equals": "LLMBlock", "mode": "insensitive"}
def test_tracking_type(self):
where = _build_prisma_where(None, None, None, None, tracking_type="tokens")
assert where["trackingType"] == "tokens"
class TestBuildRawWhere:
def test_end_filter(self):
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_raw_where(None, end, None, None)
assert '"createdAt" <= $2::timestamptz' in sql
assert end in params
def test_model_filter(self):
sql, params = _build_raw_where(None, None, None, None, model="gpt-4")
assert '"model" = $' in sql
assert "gpt-4" in params
def test_block_name_filter(self):
sql, params = _build_raw_where(None, None, None, None, block_name="LLMBlock")
assert 'LOWER("blockName") = LOWER($' in sql
assert "LLMBlock" in params
def test_all_filters_combined(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_raw_where(
start, end, "anthropic", "u1", model="claude-3", block_name="LLM"
)
# trackingType (default), start, end, provider, user_id, model, block_name
assert len(params) == 7
assert "anthropic" in params
assert "u1" in params
assert "claude-3" in params
assert "LLM" in params
def test_default_tracking_type_is_cost_usd(self):
sql, params = _build_raw_where(None, None, None, None)
assert '"trackingType" = $1' in sql
assert params[0] == "cost_usd"
def test_explicit_tracking_type_overrides_default(self):
sql, params = _build_raw_where(None, None, None, None, tracking_type="tokens")
assert params[0] == "tokens"
def _make_entry(**overrides: object) -> PlatformCostEntry:
return PlatformCostEntry.model_validate(
{
@@ -286,8 +366,9 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups (no cost_usd rows for this user)
[{"userId": "u1"}], # distinct users
[provider_row], # total agg (tracking_type=None → same as unfiltered)
]
)
mock_actions.find_many = AsyncMock(return_value=[mock_user])
@@ -301,6 +382,14 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[
[{"p50": 1000, "p75": 2000, "p95": 4000, "p99": 5000}],
[{"bucket": "$0-0.50", "count": 3}],
],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -313,6 +402,131 @@ class TestGetPlatformCostDashboard:
assert dashboard.by_provider[0].total_duration_seconds == 10.5
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].email == "a***@b.com"
assert dashboard.cost_p50_microdollars == 1000
assert dashboard.cost_p75_microdollars == 2000
assert dashboard.cost_p95_microdollars == 4000
assert dashboard.cost_p99_microdollars == 5000
assert len(dashboard.cost_buckets) == 1
# total_input/output_tokens come from total_agg_no_tracking_type_groups
# (provider_row has 1000/500)
assert dashboard.total_input_tokens == 1000
assert dashboard.total_output_tokens == 500
# Token averages must use token_bearing_requests (3) not cost_bearing (0)
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 3)
# No cost_usd rows in total_agg → avg_cost should be 0
assert dashboard.avg_cost_microdollars_per_request == 0.0
@pytest.mark.asyncio
async def test_cost_bearing_request_count_nonzero_when_filtering_by_tokens(self):
"""When filtering by tracking_type='tokens', cost_bearing_request_count
must still reflect cost_usd rows because by_user_tracking_groups is
queried without the tracking_type constraint."""
# total_agg only has a tokens row (because of the tracking_type filter)
total_row = _make_group_by_row(
provider="openai", tracking_type="tokens", cost=0, count=5
)
# by_user_tracking_groups returns BOTH rows (no tracking_type filter)
user_tracking_cost_usd_row = {
"_count": {"_all": 7},
"userId": "u1",
"trackingType": "cost_usd",
}
user_tracking_tokens_row = {
"_count": {"_all": 5},
"userId": "u1",
"trackingType": "tokens",
}
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[total_row], # by_provider
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
[
user_tracking_cost_usd_row,
user_tracking_tokens_row,
], # by_user_tracking
[{"userId": "u1"}], # distinct users
[total_row], # total agg (filtered)
[total_row], # total agg (no tracking_type filter)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
# by_user has 1 user with 5 total requests (tokens rows only due to filter)
# but per-user cost_bearing count should be 7 (from cost_usd rows in
# by_user_tracking_groups which uses where_no_tracking_type)
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].cost_bearing_request_count == 7
@pytest.mark.asyncio
async def test_global_avg_cost_nonzero_when_filtering_by_tokens(self):
"""When filtering by tracking_type='tokens', avg_cost_microdollars_per_request
must still reflect cost_usd rows from total_agg_no_tracking_type_groups,
not the filtered total_agg_groups which only has tokens rows."""
# filtered total_agg only has tokens rows (zero cost)
tokens_row = _make_group_by_row(
provider="openai", tracking_type="tokens", cost=0, count=5
)
# unfiltered total_agg has both rows (cost_usd carries the actual cost)
cost_usd_row = _make_group_by_row(
provider="openai", tracking_type="cost_usd", cost=10_000, count=4
)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[tokens_row], # by_provider
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
[], # by_user_tracking_groups
[{"userId": "u1"}], # distinct users
[tokens_row], # total agg (filtered — tokens only)
[tokens_row, cost_usd_row], # total agg (no tracking_type filter)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
# avg_cost_microdollars_per_request must be non-zero: cost_usd row
# (10_000 microdollars, 4 requests) is present in the unfiltered agg.
assert dashboard.avg_cost_microdollars_per_request == pytest.approx(10_000 / 4)
# avg token stats use token_bearing_requests from unfiltered agg (5)
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 5)
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 5)
@pytest.mark.asyncio
async def test_cache_tokens_aggregated_not_hardcoded(self):
@@ -335,8 +549,9 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups
[{"userId": "u2"}], # distinct users
[provider_row], # total agg (tracking_type=None → same as unfiltered)
]
)
mock_actions.find_many = AsyncMock(return_value=[])
@@ -350,6 +565,14 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[
[{"p50": 0, "p75": 0, "p95": 0, "p99": 0}],
[],
],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -361,7 +584,7 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
@@ -373,6 +596,11 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
dashboard = await get_platform_cost_dashboard()
@@ -381,13 +609,56 @@ class TestGetPlatformCostDashboard:
assert dashboard.total_users == 0
assert dashboard.by_provider == []
assert dashboard.by_user == []
assert dashboard.cost_p50_microdollars == 0
assert dashboard.cost_buckets == []
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
raw_mock = AsyncMock(side_effect=[[], []])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
raw_mock,
),
):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
# group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users,
# total agg filtered); the 6th call (total agg no-tracking-type) only runs
# when tracking_type is set.
assert mock_actions.group_by.await_count == 5
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})
# Raw SQL queries should receive provider and user_id as parameters
assert raw_mock.await_count == 2
raw_call_args = raw_mock.call_args_list[0][0] # positional args of 1st call
raw_params = raw_call_args[1:] # first arg is the query template
assert "openai" in raw_params
assert "u1" in raw_params
@pytest.mark.asyncio
async def test_user_tracking_groups_excludes_tracking_type_filter(self):
"""by_user_tracking_groups must NOT apply the tracking_type filter so that
cost_usd rows are always included even when the caller filters by 'tokens'."""
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
@@ -399,16 +670,23 @@ class TestGetPlatformCostDashboard:
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.query_raw_with_schema",
new_callable=AsyncMock,
side_effect=[[], []],
),
):
await get_platform_cost_dashboard(tracking_type="tokens")
# Call index 2 is by_user_tracking_groups (0=by_provider, 1=by_user,
# 2=by_user_tracking, 3=distinct_users, 4=total_agg, 5=total_agg_no_tt).
tracking_call_where = mock_actions.group_by.call_args_list[2][1]["where"]
# The main filter applies trackingType; by_user_tracking must NOT.
assert "trackingType" not in tracking_call_where
# Other filters (e.g., date range, provider) are still passed through.
# The first call (by_provider) should have trackingType in its where dict.
provider_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
assert "trackingType" in provider_call_where
def _make_prisma_log_row(

View File

@@ -0,0 +1,509 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any, cast
from backend.blocks import get_block
from backend.blocks._base import Block
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionStatus,
GraphExecutionEntry,
NodeExecutionEntry,
)
from backend.data.graph import Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
ZeroBalanceData,
)
from backend.notifications.notifications import queue_notification
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.logging import TruncatedLogger
from backend.util.metrics import DiscordChannel
from backend.util.settings import Settings
from .utils import LogMetadata, block_usage_cost, execution_usage_cost
if TYPE_CHECKING:
from backend.data.db_manager import DatabaseManagerClient
_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[Billing]")
settings = Settings()
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
# Hard cap on the multiplier passed to charge_extra_runtime_cost to
# protect against a corrupted llm_call_count draining a user's balance.
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
# 200 leaves headroom while preventing runaway charges.
_MAX_EXTRA_RUNTIME_COST = 200
def get_db_client() -> "DatabaseManagerClient":
return get_database_manager_client()
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
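A hypothetical top-up hook showing how this pairs with the lazy import in the credit.py diff above; the hook name is illustrative:

```python
async def on_user_top_up(user_id: str) -> None:
    # Lazy import, mirroring backend.data.credit, to avoid a circular dependency.
    from backend.executor.billing import clear_insufficient_funds_notifications

    # Re-arm the one-alert-per-agent guard now that the user has funds again.
    cleared = await clear_insufficient_funds_notifications(user_id)
    print(f"re-armed insufficient-funds alerts for {user_id}: {cleared} keys")
```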
def resolve_block_cost(
node_exec: NodeExecutionEntry,
) -> tuple["Block | None", int, dict[str, Any]]:
"""Look up the block and compute its base usage cost for an exec.
Shared by charge_usage and charge_extra_runtime_cost so the
(get_block, block_usage_cost) lookup lives in exactly one place.
Returns ``(block, cost, matching_filter)``. ``block`` is ``None`` if
the block id can't be resolved — callers should treat that as
"nothing to charge".
"""
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return None, 0, {}
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.inputs)
return block, cost, matching_filter
def charge_usage(
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block, cost, matching_filter = resolve_block_cost(node_exec)
if not block:
return total_cost, 0
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
total_cost += cost
# execution_count=0 is used by charge_node_usage for nested tool calls
# which must not be pushed into higher execution-count tiers.
# execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
# so skip it entirely when execution_count is 0.
cost, usage_count = (
execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
total_cost += cost
return total_cost, remaining_balance
def _charge_extra_runtime_cost_sync(
node_exec: NodeExecutionEntry,
capped_count: int,
) -> tuple[int, int]:
"""Synchronous implementation — runs in a thread-pool worker.
Called only from charge_extra_runtime_cost. Do not call directly from
async code.
Note: ``resolve_block_cost`` is called again here (rather than reusing
the result from ``charge_usage`` at the start of execution) because the
two calls happen in separate thread-pool workers and sharing mutable
state across workers would require locks. The block config is immutable
during a run, so the repeated lookup is safe and produces the same cost;
the only overhead is an extra registry lookup.
"""
db_client = get_db_client()
block, cost, matching_filter = resolve_block_cost(node_exec)
if not block or cost <= 0:
return 0, 0
total_extra_cost = cost * capped_count
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=total_extra_cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input={
**matching_filter,
"extra_runtime_cost_count": capped_count,
},
reason=(
f"Extra agent-mode iterations for {block.name} "
f"({capped_count} additional LLM calls)"
),
),
)
return total_extra_cost, remaining_balance
async def charge_extra_runtime_cost(
node_exec: NodeExecutionEntry,
extra_count: int,
) -> tuple[int, int]:
"""Charge a block extra runtime cost beyond the initial run.
Used by agent-mode blocks (e.g. OrchestratorBlock) that make multiple
LLM calls within a single node execution. The first iteration is already
charged by charge_usage; this method charges *extra_count* additional
copies of the block's base cost.
Returns ``(total_extra_cost, remaining_balance)``. May raise
``InsufficientBalanceError`` if the user can't afford the charge.
"""
if extra_count <= 0:
return 0, 0
# Cap to protect against a corrupted llm_call_count.
capped = min(extra_count, _MAX_EXTRA_RUNTIME_COST)
if extra_count > _MAX_EXTRA_RUNTIME_COST:
logger.warning(
f"extra_count {extra_count} exceeds cap {_MAX_EXTRA_RUNTIME_COST};"
f" charging {_MAX_EXTRA_RUNTIME_COST} (llm_call_count may be corrupted)"
)
return await asyncio.to_thread(_charge_extra_runtime_cost_sync, node_exec, capped)
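An illustrative call, for example from an agent-mode block that made 12 LLM calls in one node execution; the first call is already billed by `charge_usage`, so 11 extras remain:

```python
extra_cost, remaining = await charge_extra_runtime_cost(node_exec, extra_count=11)
# With a hypothetical base cost of 5 credits, extra_cost == 55.
# extra_count=500 would be clamped to _MAX_EXTRA_RUNTIME_COST (200) with a warning.
```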
async def charge_node_usage(node_exec: NodeExecutionEntry) -> tuple[int, int]:
"""Charge a single node execution to the user.
Public async wrapper around charge_usage for blocks (e.g. the
OrchestratorBlock) that spawn nested node executions outside the main
queue and therefore need to charge them explicitly.
Also handles low-balance notification so callers don't need to touch
private functions directly.
Note: this **does not** increment the global execution counter
(``increment_execution_count``). Nested tool executions are sub-steps
of a single block run from the user's perspective and should not push
them into higher per-execution cost tiers.
"""
def _run():
total_cost, remaining = charge_usage(node_exec, 0)
if total_cost > 0:
handle_low_balance(
get_db_client(), node_exec.user_id, remaining, total_cost
)
return total_cost, remaining
return await asyncio.to_thread(_run)
async def try_send_insufficient_funds_notif(
user_id: str,
graph_id: str,
error: InsufficientBalanceError,
log_metadata: LogMetadata,
) -> None:
"""Send an insufficient-funds notification, swallowing failures."""
try:
await asyncio.to_thread(
handle_insufficient_funds_notif,
get_db_client(),
user_id,
graph_id,
error,
)
except Exception as notif_error: # pragma: no cover
log_metadata.warning(
f"Failed to send insufficient funds notification: {notif_error}"
)
async def handle_post_execution_billing(
node: Node,
node_exec: NodeExecutionEntry,
execution_stats: NodeExecutionStats,
status: ExecutionStatus,
log_metadata: LogMetadata,
) -> None:
"""Charge extra runtime cost for blocks that opt into per-LLM-call billing.
The first LLM call is already covered by charge_usage(); each additional
call costs another base_cost. Skipped for dry runs and failed runs.
InsufficientBalanceError here is a post-hoc billing leak: the work is
already done but the user can no longer pay. The run stays COMPLETED and
the error is logged with ``billing_leak: True`` for alerting.
"""
extra_iterations = (
cast(Block, node.block).extra_runtime_cost(execution_stats)
if status == ExecutionStatus.COMPLETED
and not node_exec.execution_context.dry_run
else 0
)
if extra_iterations <= 0:
return
try:
extra_cost, remaining_balance = await charge_extra_runtime_cost(
node_exec,
extra_iterations,
)
if extra_cost > 0:
execution_stats.extra_cost += extra_cost
await asyncio.to_thread(
handle_low_balance,
get_db_client(),
node_exec.user_id,
remaining_balance,
extra_cost,
)
except InsufficientBalanceError as e:
log_metadata.error(
"billing_leak: insufficient balance after "
f"{node.block.name} completed {extra_iterations} "
f"extra iterations",
extra={
"billing_leak": True,
"user_id": node_exec.user_id,
"graph_id": node_exec.graph_id,
"block_id": node_exec.block_id,
"extra_runtime_cost_count": extra_iterations,
"error": str(e),
},
)
# Do NOT set execution_stats.error — the node ran to completion,
# only the post-hoc charge failed. See class-level billing-leak
# contract documentation.
await try_send_insufficient_funds_notif(
node_exec.user_id,
node_exec.graph_id,
e,
log_metadata,
)
except Exception as e:
log_metadata.error(
f"billing_leak: failed to charge extra iterations for {node.block.name}",
extra={
"billing_leak": True,
"user_id": node_exec.user_id,
"graph_id": node_exec.graph_id,
"block_id": node_exec.block_id,
"extra_runtime_cost_count": extra_iterations,
"error_type": type(e).__name__,
"error": str(e),
},
exc_info=True,
)
def handle_agent_run_notif(
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
) -> None:
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
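The `named_outputs` comprehension above unwraps only the `name` key, since `output_data` maps every key to a list of values. A small example with made-up data:

```python
output_data = {"name": ["weekly_summary"], "result": ["draft one", "draft two"]}
named = {k: v[0] if k == "name" else v for k, v in output_data.items()}
assert named == {"name": "weekly_summary", "result": ["draft one", "draft two"]}
```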
def handle_insufficient_funds_notif(
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
e: InsufficientBalanceError,
) -> None:
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(f"Failed to send insufficient funds Discord alert: {alert_error}")
def handle_low_balance(
db_client: "DatabaseManagerClient",
user_id: str,
current_balance: int,
transaction_cost: int,
) -> None:
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.warning(f"Failed to send low balance Discord alert: {e}")

View File

@@ -21,11 +21,9 @@ from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
ExecutionContext,
@@ -39,27 +37,18 @@ from backend.data.execution import (
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cost_tracking import (
drain_pending_cost_logs,
log_system_credential_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
)
from backend.util.decorator import (
async_error_logged,
@@ -75,7 +64,6 @@ from backend.util.exceptions import (
)
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
continuous_retry,
@@ -84,6 +72,7 @@ from backend.util.retry import (
)
from backend.util.settings import Settings
from . import billing
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
@@ -98,9 +87,7 @@ from .utils import (
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
create_execution_queue_config,
validate_exec,
)
@@ -126,40 +113,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -681,12 +634,16 @@ class ExecutionProcessor:
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
await billing.handle_post_execution_billing(
node, node_exec, execution_stats, status, log_metadata
)
graph_stats, graph_stats_lock = graph_stats_pair
with graph_stats_lock:
graph_stats.node_count += 1 + execution_stats.extra_steps
graph_stats.nodes_cputime += execution_stats.cputime
graph_stats.nodes_walltime += execution_stats.walltime
graph_stats.cost += execution_stats.cost + execution_stats.extra_cost
if isinstance(execution_stats.error, Exception):
graph_stats.node_error_count += 1
@@ -716,6 +673,18 @@ class ExecutionProcessor:
db_client=db_client,
)
# If the node failed because a nested tool charge raised
# InsufficientBalanceError, send the user notification so they
# understand why the run stopped.
if status == ExecutionStatus.FAILED and isinstance(
execution_stats.error, InsufficientBalanceError
):
await billing.try_send_insufficient_funds_notif(
node_exec.user_id,
node_exec.graph_id,
execution_stats.error,
log_metadata,
)
return execution_stats
@async_time_measured
@@ -935,7 +904,7 @@ class ExecutionProcessor:
)
finally:
# Communication handling
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
billing.handle_agent_run_notif(db_client, graph_exec, exec_stats)
update_graph_execution_state(
db_client=db_client,
@@ -944,57 +913,18 @@ class ExecutionProcessor:
stats=exec_stats,
)
def _charge_usage(
async def charge_node_usage(
self,
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return total_cost, 0
return await billing.charge_node_usage(node_exec)
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
total_cost += cost
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
total_cost += cost
return total_cost, remaining_balance
async def charge_extra_runtime_cost(
self,
node_exec: NodeExecutionEntry,
extra_count: int,
) -> tuple[int, int]:
return await billing.charge_extra_runtime_cost(node_exec, extra_count)
@time_measured
def _on_graph_execution(
@@ -1106,7 +1036,7 @@ class ExecutionProcessor:
# Charge usage (may raise) — skipped for dry runs
try:
if not graph_exec.execution_context.dry_run:
cost, remaining_balance = self._charge_usage(
cost, remaining_balance = billing.charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(
graph_exec.user_id
@@ -1115,7 +1045,7 @@ class ExecutionProcessor:
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
billing.handle_low_balance(
db_client=db_client,
user_id=graph_exec.user_id,
current_balance=remaining_balance,
@@ -1135,7 +1065,7 @@ class ExecutionProcessor:
status=ExecutionStatus.FAILED,
)
self._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client,
graph_exec.user_id,
graph_exec.graph_id,
@@ -1397,165 +1327,6 @@ class ExecutionProcessor:
):
execution_queue.add(next_execution)
def _handle_agent_run_notif(
self,
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
def _handle_insufficient_funds_notif(
self,
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(
f"Failed to send insufficient funds Discord alert: {alert_error}"
)
def _handle_low_balance(
self,
db_client: "DatabaseManagerClient",
user_id: str,
current_balance: int,
transaction_cost: int,
):
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.warning(f"Failed to send low balance Discord alert: {e}")
class ExecutionManager(AppProcess):
def __init__(self):

View File

@@ -4,9 +4,9 @@ import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
from backend.executor import billing
from backend.executor.billing import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
@@ -25,7 +25,6 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -36,13 +35,13 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Setup mocks
@@ -63,7 +62,7 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -99,7 +98,6 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -110,13 +108,13 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Setup mocks
@@ -134,7 +132,7 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -154,7 +152,6 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
@@ -166,12 +163,12 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
with patch("backend.executor.billing.queue_notification"), patch(
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
mock_client = MagicMock()
@@ -190,7 +187,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
@@ -198,7 +195,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
@@ -227,7 +224,7 @@ async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
@@ -263,7 +260,7 @@ async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestSe
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
@@ -290,7 +287,7 @@ async def test_clear_insufficient_funds_notifications_handles_redis_error(
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
with patch("backend.executor.billing.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
@@ -310,7 +307,6 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
@@ -321,13 +317,13 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
)
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
mock_client = MagicMock()
@@ -346,7 +342,7 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
billing.handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
@@ -370,7 +366,7 @@ async def test_add_transaction_clears_notifications_on_grant(server: SpinTestSer
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -412,7 +408,7 @@ async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestSe
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -450,7 +446,7 @@ async def test_add_transaction_skips_clearing_for_inactive_transaction(
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -486,7 +482,7 @@ async def test_add_transaction_skips_clearing_for_usage_transaction(
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
"backend.executor.billing.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
@@ -521,7 +517,7 @@ async def test_enable_transaction_clears_notifications(server: SpinTestServer):
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
) as mock_query, patch("backend.executor.billing.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()

View File

@@ -4,26 +4,25 @@ import pytest
from prisma.enums import NotificationType
from backend.data.notifications import LowBalanceData
from backend.executor.manager import ExecutionProcessor
from backend.executor import billing
from backend.util.test import SpinTestServer
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
"""Test that _handle_low_balance triggers notification when crossing threshold."""
"""Test that handle_low_balance triggers notification when crossing threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 400 # $4 - below $5 threshold
transaction_cost = 600 # $6 transaction
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -37,7 +36,7 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,
@@ -69,7 +68,6 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
):
"""Test that no notification is sent when not crossing the threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 600 # $6 - above $5 threshold
transaction_cost = (
@@ -78,11 +76,11 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -94,7 +92,7 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
mock_db_client = MagicMock()
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,
@@ -112,7 +110,6 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
):
"""Test that no notification is sent when already below threshold."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
current_balance = 300 # $3 - below $5 threshold
transaction_cost = (
@@ -121,11 +118,11 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
# Mock dependencies
with patch(
"backend.executor.manager.queue_notification"
"backend.executor.billing.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
"backend.executor.billing.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
"backend.executor.billing.settings"
) as mock_settings:
# Setup mocks
@@ -137,7 +134,7 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
mock_db_client = MagicMock()
# Test the low balance handler
execution_processor._handle_low_balance(
billing.handle_low_balance(
db_client=mock_db_client,
user_id=user_id,
current_balance=current_balance,

View File

@@ -18,9 +18,13 @@ images: {
"""
import asyncio
import json
import random
from pathlib import Path
from typing import Any, Dict, List
import prisma.enums as prisma_enums
import prisma.models as prisma_models
from faker import Faker
# Import API functions from the backend
@@ -30,10 +34,12 @@ from backend.api.features.store.db import (
create_store_submission,
review_store_submission,
)
from backend.api.features.store.model import StoreSubmission
from backend.blocks.io import AgentInputBlock
from backend.data.auth.api_key import create_api_key
from backend.data.credit import get_user_credit_model
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.graph import Graph, Link, Node, create_graph, make_graph_model
from backend.data.user import get_or_create_user
from backend.util.clients import get_supabase
@@ -60,6 +66,31 @@ MAX_REVIEWS_PER_VERSION = 5
GUARANTEED_FEATURED_AGENTS = 8
GUARANTEED_FEATURED_CREATORS = 5
GUARANTEED_TOP_AGENTS = 10
E2E_MARKETPLACE_CREATOR_EMAIL = "test123@example.com"
E2E_MARKETPLACE_CREATOR_USERNAME = "e2e-marketplace"
E2E_MARKETPLACE_AGENT_SLUG = "e2e-calculator-agent"
E2E_MARKETPLACE_AGENT_NAME = "E2E Calculator Agent"
E2E_MARKETPLACE_AGENT_INPUT_VALUE = 8
E2E_MARKETPLACE_AGENT_OUTPUT_VALUE = 42
_LOCAL_TEMPLATE_PATH = (
Path(__file__).resolve().parents[1] / "agents" / "calculator-agent.json"
)
_DOCKER_TEMPLATE_PATH = Path(
"/app/autogpt_platform/backend/agents/calculator-agent.json"
)
E2E_MARKETPLACE_AGENT_TEMPLATE_PATH = (
_LOCAL_TEMPLATE_PATH if _LOCAL_TEMPLATE_PATH.exists() else _DOCKER_TEMPLATE_PATH
)
SEEDED_TEST_EMAILS = [
"test123@example.com",
"e2e.qa.auth@example.com",
"e2e.qa.builder@example.com",
"e2e.qa.library@example.com",
"e2e.qa.marketplace@example.com",
"e2e.qa.settings@example.com",
"e2e.qa.parallel.a@example.com",
"e2e.qa.parallel.b@example.com",
]
def get_image():
@@ -100,6 +131,25 @@ def get_category():
return random.choice(categories)
def load_deterministic_marketplace_graph() -> Graph:
graph = Graph.model_validate(
json.loads(E2E_MARKETPLACE_AGENT_TEMPLATE_PATH.read_text())
)
graph.name = E2E_MARKETPLACE_AGENT_NAME
graph.description = (
"Deterministic marketplace calculator graph for Playwright PR E2E coverage."
)
for node in graph.nodes:
if (
node.block_id == AgentInputBlock().id
and node.input_default.get("value") is None
):
node.input_default["value"] = E2E_MARKETPLACE_AGENT_INPUT_VALUE
return graph
class TestDataCreator:
"""Creates test data using API functions for E2E tests."""
@@ -123,9 +173,9 @@ class TestDataCreator:
for i in range(NUM_USERS):
try:
# Generate test user data
if i == 0:
# First user should have test123@gmail.com email for testing
email = "test123@gmail.com"
if i < len(SEEDED_TEST_EMAILS):
# Keep a deterministic pool for Playwright global setup and PR smoke flows
email = SEEDED_TEST_EMAILS[i]
else:
email = faker.unique.email()
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
@@ -547,6 +597,46 @@ class TestDataCreator:
print(f"Error updating profile {profile.id}: {e}")
continue
deterministic_creator = next(
(
user
for user in self.users
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
),
None,
)
if deterministic_creator:
deterministic_profile = next(
(
profile
for profile in existing_profiles
if profile.userId == deterministic_creator["id"]
),
None,
)
if deterministic_profile:
try:
updated_profile = await prisma.profile.update(
where={"id": deterministic_profile.id},
data={
"name": "E2E Marketplace Creator",
"username": E2E_MARKETPLACE_CREATOR_USERNAME,
"description": "Deterministic marketplace creator for Playwright PR E2E coverage.",
"links": ["https://example.com/e2e-marketplace"],
"avatarUrl": get_image(),
"isFeatured": True,
},
)
profiles = [
profile
for profile in profiles
if profile.get("id") != deterministic_profile.id
]
if updated_profile is not None:
profiles.append(updated_profile.model_dump())
except Exception as e:
print(f"Error updating deterministic E2E creator profile: {e}")
self.profiles = profiles
return profiles
@@ -562,58 +652,184 @@ class TestDataCreator:
featured_count = 0
submission_counter = 0
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
# Create a deterministic calculator marketplace agent for PR E2E coverage
test_user = next(
(user for user in self.users if user["email"] == "test123@gmail.com"), None
(
user
for user in self.users
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
),
None,
)
if test_user and self.agent_graphs:
test_submission_data = {
"user_id": test_user["id"],
"graph_id": self.agent_graphs[0]["id"],
"graph_version": 1,
"slug": "test-agent-submission",
"name": "Test Agent Submission",
"sub_heading": "A test agent for frontend testing",
"video_url": "https://www.youtube.com/watch?v=test123",
"image_urls": [
"https://picsum.photos/200/300",
"https://picsum.photos/200/301",
"https://picsum.photos/200/302",
],
"description": "This is a test agent submission specifically created for frontend testing purposes.",
"categories": ["test", "demo", "frontend"],
"changes_summary": "Initial test submission",
}
if test_user:
deterministic_graph = None
try:
test_submission = await create_store_submission(**test_submission_data)
submissions.append(test_submission.model_dump())
print("✅ Created special test store submission for test123@gmail.com")
# ALWAYS approve and feature the test submission
if test_submission.listing_version_id:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Test submission approved",
internal_comments="Auto-approved test submission",
reviewer_id=test_user["id"],
existing_graph = await prisma_models.AgentGraph.prisma().find_first(
where={
"userId": test_user["id"],
"name": E2E_MARKETPLACE_AGENT_NAME,
"isActive": True,
},
order={"version": "desc"},
)
if existing_graph:
deterministic_graph = {
"id": existing_graph.id,
"version": existing_graph.version,
"name": existing_graph.name,
"userId": test_user["id"],
}
self.agent_graphs.append(deterministic_graph)
print(
"✅ Reused existing deterministic marketplace graph: "
f"{existing_graph.id}"
)
approved_submissions.append(approved_submission.model_dump())
print("✅ Approved test store submission")
await prisma.storelistingversion.update(
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
else:
deterministic_graph_model = make_graph_model(
load_deterministic_marketplace_graph(),
test_user["id"],
)
featured_count += 1
print("🌟 Marked test agent as FEATURED")
deterministic_graph_model.reassign_ids(
user_id=test_user["id"],
reassign_graph_id=True,
)
created_deterministic_graph = await create_graph(
deterministic_graph_model,
test_user["id"],
)
deterministic_graph = created_deterministic_graph.model_dump()
deterministic_graph["userId"] = test_user["id"]
self.agent_graphs.append(deterministic_graph)
print("✅ Created deterministic marketplace graph")
except Exception as e:
print(f"Error creating test store submission: {e}")
import traceback
print(f"Error creating deterministic marketplace graph: {e}")
traceback.print_exc()
if deterministic_graph is None and self.agent_graphs:
test_user_graphs = [
graph
for graph in self.agent_graphs
if graph.get("userId") == test_user["id"]
]
deterministic_graph = next(
(
graph
for graph in test_user_graphs
if not graph.get("name", "").startswith("DummyInput ")
),
test_user_graphs[0] if test_user_graphs else None,
)
if deterministic_graph:
test_submission_data = {
"user_id": test_user["id"],
"graph_id": deterministic_graph["id"],
"graph_version": deterministic_graph.get("version", 1),
"slug": E2E_MARKETPLACE_AGENT_SLUG,
"name": E2E_MARKETPLACE_AGENT_NAME,
"sub_heading": "A deterministic calculator agent for PR E2E coverage",
"video_url": "https://www.youtube.com/watch?v=test123",
"image_urls": [
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
],
"description": (
"A deterministic marketplace calculator agent that adds "
f"{E2E_MARKETPLACE_AGENT_INPUT_VALUE} and 34 to produce "
f"{E2E_MARKETPLACE_AGENT_OUTPUT_VALUE} for frontend E2E coverage."
),
"categories": ["test", "demo", "frontend"],
"changes_summary": (
"Initial deterministic calculator submission seeded from "
"backend/agents/calculator-agent.json"
),
}
try:
existing_deterministic_submission = (
await prisma_models.StoreListingVersion.prisma().find_first(
where={
"isDeleted": False,
"StoreListing": {
"is": {
"owningUserId": test_user["id"],
"slug": E2E_MARKETPLACE_AGENT_SLUG,
"isDeleted": False,
}
},
},
include={"StoreListing": True},
order={"version": "desc"},
)
)
if existing_deterministic_submission:
test_submission = StoreSubmission.from_listing_version(
existing_deterministic_submission
)
submissions.append(test_submission.model_dump())
print(
"✅ Reused deterministic marketplace submission: "
f"{E2E_MARKETPLACE_AGENT_NAME}"
)
else:
test_submission = await create_store_submission(
**test_submission_data
)
submissions.append(test_submission.model_dump())
print(
"✅ Created deterministic marketplace submission: "
f"{E2E_MARKETPLACE_AGENT_NAME}"
)
current_status = (
existing_deterministic_submission.submissionStatus
if existing_deterministic_submission
else test_submission.status
)
is_featured = bool(
existing_deterministic_submission
and existing_deterministic_submission.isFeatured
)
if test_submission.listing_version_id:
if current_status != prisma_enums.SubmissionStatus.APPROVED:
approved_submission = await review_store_submission(
store_listing_version_id=test_submission.listing_version_id,
is_approved=True,
external_comments="Deterministic calculator submission approved",
internal_comments="Auto-approved PR E2E marketplace submission",
reviewer_id=test_user["id"],
)
approved_submissions.append(
approved_submission.model_dump()
)
print("✅ Approved deterministic marketplace submission")
else:
approved_submissions.append(test_submission.model_dump())
print(
"✅ Deterministic marketplace submission already approved"
)
if is_featured:
featured_count += 1
print("🌟 Deterministic marketplace agent already FEATURED")
else:
await prisma.storelistingversion.update(
where={"id": test_submission.listing_version_id},
data={"isFeatured": True},
)
featured_count += 1
print(
"🌟 Marked deterministic marketplace agent as FEATURED"
)
except Exception as e:
print(f"Error creating deterministic marketplace submission: {e}")
import traceback
traceback.print_exc()
# Create regular submissions for all users
for user in self.users:

View File

@@ -6,7 +6,8 @@
# 5. CLI arguments - docker compose run -e VAR=value
# Common backend environment - Docker service names
x-backend-env: &backend-env # Docker internal service hostnames (override localhost defaults)
x-backend-env:
&backend-env # Docker internal service hostnames (override localhost defaults)
PYRO_HOST: "0.0.0.0"
AGENTSERVER_HOST: rest_server
SCHEDULER_HOST: scheduler_server
@@ -39,7 +40,12 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: migrate
command: ["sh", "-c", "prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy"]
command:
[
"sh",
"-c",
"prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy",
]
develop:
watch:
- path: ./
@@ -79,8 +85,8 @@ services:
falkordb:
image: falkordb/falkordb:latest
ports:
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
- "3001:3000" # FalkorDB web UI
environment:
- REDIS_ARGS=--requirepass ${GRAPHITI_FALKORDB_PASSWORD:-}
volumes:
@@ -88,7 +94,11 @@ services:
networks:
- app-network
healthcheck:
test: ["CMD-SHELL", "redis-cli -p 6379 -a \"${GRAPHITI_FALKORDB_PASSWORD:-}\" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1"]
test:
[
"CMD-SHELL",
'redis-cli -p 6379 -a "${GRAPHITI_FALKORDB_PASSWORD:-}" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1',
]
interval: 10s
timeout: 5s
retries: 5
@@ -300,19 +310,6 @@ services:
condition: service_completed_successfully
database_manager:
condition: service_started
# healthcheck:
# test:
# [
# "CMD",
# "curl",
# "-f",
# "-X",
# "POST",
# "http://localhost:8003/health_check",
# ]
# interval: 10s
# timeout: 10s
# retries: 5
<<: *backend-env-files
environment:
<<: *backend-env

View File

@@ -193,3 +193,4 @@ services:
- copilot_executor
- websocket_server
- database_manager
- scheduler_server

View File

@@ -8,6 +8,7 @@ const config: StorybookConfig = {
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/renderers/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/app/[(]platform[)]/copilot/**/*.stories.@(js|jsx|mjs|ts|tsx)",
],
addons: [
"@storybook/addon-a11y",

View File

@@ -81,8 +81,10 @@ Every time a new Front-end dependency is added by you or others, you will need t
- `pnpm lint` - Run ESLint and Prettier checks
- `pnpm format` - Format code with Prettier
- `pnpm types` - Run TypeScript type checking
- `pnpm test` - Run Playwright tests
- `pnpm test-ui` - Run Playwright tests with UI
- `pnpm test:unit` - Run the Vitest integration and unit suite with coverage
- `pnpm test` - Run the Playwright E2E suite used in CI
- `pnpm test-ui` - Run the same Playwright E2E suite with UI
- `pnpm test:e2e:no-build` - Run the same Playwright E2E suite against a running app
- `pnpm fetch:openapi` - Fetch OpenAPI spec from backend
- `pnpm generate:api-client` - Generate API client from OpenAPI spec
- `pnpm generate:api` - Fetch OpenAPI spec and generate API client
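For example, a typical local loop after pulling backend changes might look like this (a sketch; it assumes the dev stack is already running):

```bash
# Regenerate the API client from the backend's OpenAPI spec
pnpm generate:api

# Run the Vitest unit/integration suite with coverage
pnpm test:unit

# Run the Playwright E2E suite against the already-running app
pnpm test:e2e:no-build
```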

View File

@@ -121,35 +121,49 @@ Only when the component has complex internal logic that is hard to exercise thro
### Running
```bash
pnpm test # build + run all Playwright tests
pnpm test-ui # run with Playwright UI
pnpm test:no-build # run against a running dev server
pnpm test # build + run the Playwright E2E suite used in CI
pnpm test-ui # run the same E2E suite with Playwright UI
pnpm test:e2e:no-build # run the same E2E suite against a running dev server
pnpm exec playwright test # run the same eight-spec Playwright suite directly
```
### Setup
1. Start the backend + Supabase stack:
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
2. Seed rich E2E data (creates `test123@example.com` with library agents):
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
### How Playwright setup works
- Playwright runs from `frontend/playwright.config.ts` with a global setup step
- Global setup creates a user pool via the real signup UI, stored in `frontend/.auth/user-pool.json`
- `getTestUser()` (from `src/tests/utils/auth.ts`) pulls a random user from the pool
- Playwright runs from `frontend/playwright.config.ts` and keeps browser-only code in `frontend/src/playwright/`
- Global setup creates reusable auth states for deterministic seeded accounts in `frontend/.auth/states/`
- `getTestUser()` (from `src/playwright/utils/auth.ts`) picks one seeded account for general auth coverage
- `getTestUserWithLibraryAgents()` uses the rich user created by the data script; a minimal usage sketch follows below
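As an illustration, a spec can pull a seeded account and sign in before exercising a page. This is a minimal sketch, not the canonical pattern: the exact helper signatures live in `src/playwright/utils/auth.ts`, and real specs typically reuse the stored auth states instead of logging in by hand.

```typescript
import { test, expect } from "@playwright/test";
// Assumed import path based on the notes above; verify against src/playwright/utils/auth.ts
import { getTestUser } from "./utils/auth";

test("library loads for a seeded user", async ({ page }) => {
  // Hypothetical shape: a seeded account exposing email/password credentials
  const user = await getTestUser();

  await page.goto("/login");
  await page.fill('input[name="email"]', user.email);
  await page.fill('input[name="password"]', user.password);
  await page.click('button[type="submit"]');

  await expect(page).toHaveURL(/library/);
});
```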
### Test users
- **User pool (basic users)** — created automatically by Playwright global setup. Used by `getTestUser()`
- **Seeded E2E accounts** — created by backend fixtures and logged in during Playwright global setup. Used by `getTestUser()` and `E2E_AUTH_STATES`
- **Rich user with library agents** — created by `backend/test/e2e_test_data.py`. Used by `getTestUserWithLibraryAgents()`
### Current Playwright E2E suite
The CI suite is intentionally limited to the cross-page journeys that still require a real browser. Playwright discovers the PR-gating specs via the `*-happy-path.spec.ts` naming pattern inside `src/playwright/` (see the single-spec example after this list):
- `src/playwright/auth-happy-path.spec.ts`
- `src/playwright/settings-happy-path.spec.ts`
- `src/playwright/api-keys-happy-path.spec.ts`
- `src/playwright/builder-happy-path.spec.ts`
- `src/playwright/library-happy-path.spec.ts`
- `src/playwright/marketplace-happy-path.spec.ts`
- `src/playwright/publish-happy-path.spec.ts`
- `src/playwright/copilot-happy-path.spec.ts`
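To run a single journey on its own (assuming the backend stack and seeded data from the Setup steps above), pass the spec path directly:

```bash
pnpm exec playwright test src/playwright/library-happy-path.spec.ts
```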
### Resetting the DB
If you reset the Docker DB and logins start failing (see the command sketch below):
1. Delete `frontend/.auth/user-pool.json`
1. Delete `frontend/.auth/states/*` and, if present, the legacy `frontend/.auth/user-pool.json`
2. Re-run `poetry run python test/e2e_test_data.py`
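A sketch of the same reset as shell commands, assuming you start from `autogpt_platform`:

```bash
# 1. Drop stale Playwright auth state (user-pool.json only exists in older checkouts)
rm -rf frontend/.auth/states frontend/.auth/user-pool.json

# 2. Re-seed the deterministic E2E data
cd backend && poetry run python test/e2e_test_data.py
```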
## Storybook

View File

@@ -13,11 +13,13 @@
"lint": "next lint && prettier --check .",
"format": "next lint --fix; prettier --write .",
"types": "tsc --noEmit",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:ui",
"test:unit": "vitest run --coverage",
"test:unit:watch": "vitest",
"test:no-build": "playwright test",
"test:e2e": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
"test:e2e:no-build": "playwright test",
"test:e2e:ui": "playwright test --ui",
"gentests": "playwright codegen http://localhost:3000",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

View File

@@ -7,10 +7,22 @@ import { defineConfig, devices } from "@playwright/test";
import dotenv from "dotenv";
import fs from "fs";
import path from "path";
import { buildCookieConsentStorageState } from "./src/playwright/credentials/storage-state";
dotenv.config({ path: path.resolve(__dirname, ".env") });
dotenv.config({ path: path.resolve(__dirname, "../backend/.env") });
const frontendRoot = __dirname.replaceAll("\\", "/");
const configuredBaseURL =
process.env.PLAYWRIGHT_BASE_URL ?? "http://localhost:3000";
const parsedBaseURL = new URL(configuredBaseURL);
const baseURL = parsedBaseURL.toString().replace(/\/$/, "");
const baseOrigin = parsedBaseURL.origin;
const jsonReporterOutputFile = process.env.PLAYWRIGHT_JSON_OUTPUT_FILE;
const configuredWorkers = process.env.PLAYWRIGHT_WORKERS
? Number(process.env.PLAYWRIGHT_WORKERS)
: process.env.CI
? 8
: undefined;
// Directory where CI copies .next/static from the Docker container
const staticCoverageDir = path.resolve(__dirname, ".next-static-coverage");
@@ -57,17 +69,18 @@ function resolveSourceMap(sourcePath: string) {
}
export default defineConfig({
testDir: "./src/tests",
testDir: "./src/playwright",
testMatch: /.*-happy-path\.spec\.ts/,
/* Global setup file that runs before all tests */
globalSetup: "./src/tests/global-setup.ts",
globalSetup: "./src/playwright/global-setup.ts",
/* Run tests in files in parallel */
fullyParallel: true,
/* Fail the build on CI if you accidentally left test.only in the source code. */
forbidOnly: !!process.env.CI,
/* Retry on CI only */
retries: process.env.CI ? 1 : 0,
/* use more workers on CI. */
workers: process.env.CI ? 4 : undefined,
retries: process.env.CI ? Number(process.env.PLAYWRIGHT_RETRIES ?? 2) : 0,
/* Higher worker count keeps PR smoke runtime down without sharing page state. */
workers: configuredWorkers,
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
reporter: [
["list"],
@@ -92,40 +105,25 @@ export default defineConfig({
},
},
],
...(jsonReporterOutputFile
? [["json", { outputFile: jsonReporterOutputFile }] as const]
: []),
],
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
use: {
/* Base URL to use in actions like `await page.goto('/')`. */
baseURL: "http://localhost:3000/",
baseURL,
/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
screenshot: "only-on-failure",
bypassCSP: true,
/* Helps debugging failures */
trace: "retain-on-failure",
video: "retain-on-failure",
trace: process.env.CI ? "on-first-retry" : "retain-on-failure",
video: process.env.CI ? "off" : "retain-on-failure",
/* Auto-accept cookies in all tests to prevent banner interference */
storageState: {
cookies: [],
origins: [
{
origin: "http://localhost:3000",
localStorage: [
{
name: "autogpt_cookie_consent",
value: JSON.stringify({
hasConsented: true,
timestamp: Date.now(),
analytics: true,
monitoring: true,
}),
},
],
},
],
},
storageState: buildCookieConsentStorageState(baseOrigin),
},
/* Maximum time one test can run for */
timeout: 25000,
@@ -133,7 +131,7 @@ export default defineConfig({
/* Configure web server to start automatically (local dev only) */
webServer: {
command: "pnpm start",
url: "http://localhost:3000",
url: baseURL,
reuseExistingServer: true,
},

View File

@@ -29,6 +29,16 @@ const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
total_input_tokens: 0,
total_output_tokens: 0,
avg_input_tokens_per_request: 0,
avg_output_tokens_per_request: 0,
avg_cost_microdollars_per_request: 0,
cost_p50_microdollars: 0,
cost_p75_microdollars: 0,
cost_p95_microdollars: 0,
cost_p99_microdollars: 0,
cost_buckets: [],
by_provider: [],
by_user: [],
};
@@ -47,6 +57,20 @@ const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
total_input_tokens: 150000,
total_output_tokens: 60000,
avg_input_tokens_per_request: 2500,
avg_output_tokens_per_request: 1000,
avg_cost_microdollars_per_request: 83333,
cost_p50_microdollars: 50000,
cost_p75_microdollars: 100000,
cost_p95_microdollars: 250000,
cost_p99_microdollars: 500000,
cost_buckets: [
{ bucket: "$0-0.50", count: 80 },
{ bucket: "$0.50-1", count: 15 },
{ bucket: "$1-2", count: 5 },
],
by_provider: [
{
provider: "openai",
@@ -75,6 +99,7 @@ const dashboardWithData: PlatformCostDashboard = {
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
cost_bearing_request_count: 40,
},
],
};
@@ -134,9 +159,14 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
// Known Cost and Estimated Total cards render $0.0000
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
// Typical/Upper/High/Peak Cost) show $0.0000
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(2);
expect(zeroCostItems.length).toBe(7);
expect(screen.getByText("No cost data yet")).toBeDefined();
});
@@ -155,7 +185,9 @@ describe("PlatformCostContent", () => {
);
expect(screen.getByText("$5.0000")).toBeDefined();
expect(screen.getByText("100")).toBeDefined();
expect(screen.getByText("5")).toBeDefined();
// "5" appears in multiple places (Active Users card + bucket count),
// so verify at least one element renders it.
expect(screen.getAllByText("5").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("openai")).toBeDefined();
expect(screen.getByText("google_maps")).toBeDefined();
});
@@ -223,10 +255,83 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Original 4 cards
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
// New average/token cards
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
expect(screen.getByText("Total Tokens")).toBeDefined();
// Percentile cards (friendlier labels)
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
expect(screen.getByText("High Cost (P95)")).toBeDefined();
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
});
it("renders cost distribution buckets", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
expect(screen.getByText("$0-0.50")).toBeDefined();
expect(screen.getByText("$0.50-1")).toBeDefined();
expect(screen.getByText("$1-2")).toBeDefined();
expect(screen.getByText("80")).toBeDefined();
expect(screen.getByText("15")).toBeDefined();
});
it("renders new summary card values from fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Avg Input Tokens: 2500 formatted
expect(screen.getByText("2,500")).toBeDefined();
// Avg Output Tokens: 1000 formatted
expect(screen.getByText("1,000")).toBeDefined();
// P50 cost: 50000 microdollars = $0.0500
expect(screen.getByText("$0.0500")).toBeDefined();
});
it("renders user table avg cost column with fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// User table should show Avg Cost / Req header
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
// Input/Output token columns
expect(screen.getByText("Input Tokens")).toBeDefined();
expect(screen.getByText("Output Tokens")).toBeDefined();
});
it("renders filter inputs", async () => {

View File

@@ -2,12 +2,13 @@
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { formatMicrodollars } from "../helpers";
import { formatMicrodollars, formatTokens } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
import { UserTable } from "./UserTable";
import { LogsTable } from "./LogsTable";
import { usePlatformCostContent } from "./usePlatformCostContent";
import type { CostBucket } from "@/app/api/__generated__/models/costBucket";
interface Props {
searchParams: {
@@ -54,6 +55,76 @@ export function PlatformCostContent({ searchParams }: Props) {
handleExport,
} = usePlatformCostContent(searchParams);
const summaryCards: { label: string; value: string; subtitle?: string }[] =
dashboard
? [
{
label: "Known Cost",
value: formatMicrodollars(dashboard.total_cost_microdollars),
subtitle: "From providers that report USD cost",
},
{
label: "Estimated Total",
value: formatMicrodollars(totalEstimatedCost),
subtitle: "Including per-run cost estimates",
},
{
label: "Total Requests",
value: dashboard.total_requests.toLocaleString(),
},
{
label: "Active Users",
value: dashboard.total_users.toLocaleString(),
},
{
label: "Avg Cost / Request",
value: formatMicrodollars(
dashboard.avg_cost_microdollars_per_request ?? 0,
),
subtitle: "Known cost divided by cost-bearing requests",
},
{
label: "Avg Input Tokens",
value: Math.round(
dashboard.avg_input_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Prompt tokens per request (context size)",
},
{
label: "Avg Output Tokens",
value: Math.round(
dashboard.avg_output_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Completion tokens per request (response length)",
},
{
label: "Total Tokens",
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
subtitle: "Prompt vs completion token split",
},
{
label: "Typical Cost (P50)",
value: formatMicrodollars(dashboard.cost_p50_microdollars ?? 0),
subtitle: "Median cost per request",
},
{
label: "Upper Cost (P75)",
value: formatMicrodollars(dashboard.cost_p75_microdollars ?? 0),
subtitle: "75th percentile cost",
},
{
label: "High Cost (P95)",
value: formatMicrodollars(dashboard.cost_p95_microdollars ?? 0),
subtitle: "95th percentile cost",
},
{
label: "Peak Cost (P99)",
value: formatMicrodollars(dashboard.cost_p99_microdollars ?? 0),
subtitle: "99th percentile cost",
},
]
: [];
return (
<div className="flex flex-col gap-6">
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
@@ -204,37 +275,54 @@ export function PlatformCostContent({ searchParams }: Props) {
{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(4)].map((_, i) => (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{/* 12 skeleton placeholders — one per summary card */}
{Array.from({ length: 12 }, (_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
<Skeleton className="h-32 rounded-lg" />
<Skeleton className="h-8 w-48 rounded" />
<Skeleton className="h-64 rounded-lg" />
</div>
) : (
<>
{dashboard && (
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
</div>
<>
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{summaryCards.map((card) => (
<SummaryCard
key={card.label}
label={card.label}
value={card.value}
subtitle={card.subtitle}
/>
))}
</div>
{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
<div className="rounded-lg border p-4">
<h3 className="mb-3 text-sm font-medium">
Cost Distribution by Bucket
</h3>
<div className="grid grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-6">
{dashboard.cost_buckets.map((b: CostBucket) => (
<div
key={b.bucket}
className="flex flex-col items-center rounded border p-2 text-center"
>
<span className="text-xs text-muted-foreground">
{b.bucket}
</span>
<span className="text-lg font-semibold">
{b.count.toLocaleString()}
</span>
</div>
))}
</div>
</div>
)}
</>
)}
<div

View File

@@ -3,6 +3,7 @@ import {
defaultRateFor,
estimateCostForRow,
formatMicrodollars,
formatTokens,
rateKey,
rateUnitLabel,
trackingValue,
@@ -33,6 +34,20 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<th scope="col" className="px-4 py-3 text-right">
Usage
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Input Tokens
</th>
<th
scope="col"
className="px-4 py-3 text-right"
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
>
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Requests
</th>
@@ -74,6 +89,16 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
<TrackingBadge trackingType={row.tracking_type} />
</td>
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
<td className="px-4 py-3 text-right">
{row.total_input_tokens > 0
? formatTokens(row.total_input_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.total_output_tokens > 0
? formatTokens(row.total_output_tokens)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{row.request_count.toLocaleString()}
</td>
@@ -124,7 +149,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
{data.length === 0 && (
<tr>
<td
colSpan={8}
colSpan={10}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

View File

@@ -27,10 +27,7 @@ function UserTable({ data }: Props) {
Output Tokens
</th>
<th scope="col" className="px-4 py-3 text-right">
Cache Read
</th>
<th scope="col" className="px-4 py-3 text-right">
Cache Write
Avg Cost / Req
</th>
</tr>
</thead>
@@ -61,13 +58,12 @@ function UserTable({ data }: Props) {
{formatTokens(row.total_output_tokens)}
</td>
<td className="px-4 py-3 text-right">
{(row.total_cache_read_tokens ?? 0) > 0
? formatTokens(row.total_cache_read_tokens ?? 0)
: "-"}
</td>
<td className="px-4 py-3 text-right">
{(row.total_cache_creation_tokens ?? 0) > 0
? formatTokens(row.total_cache_creation_tokens ?? 0)
{(row.cost_bearing_request_count ?? 0) > 0 &&
row.total_cost_microdollars > 0
? formatMicrodollars(
row.total_cost_microdollars /
(row.cost_bearing_request_count ?? 1),
)
: "-"}
</td>
</tr>
@@ -75,7 +71,7 @@ function UserTable({ data }: Props) {
{data.length === 0 && (
<tr>
<td
colSpan={7}
colSpan={6}
className="px-4 py-8 text-center text-muted-foreground"
>
No cost data yet

View File

@@ -0,0 +1,145 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { ArtifactCard } from "./ArtifactCard";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.html",
mimeType: "text/html",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
const meta: Meta<typeof ArtifactCard> = {
title: "Copilot/ArtifactCard",
component: ArtifactCard,
tags: ["autodocs"],
parameters: {
layout: "padded",
docs: {
description: {
component:
"Inline artifact card rendered in chat messages. Openable artifacts show a caret and open the ArtifactPanel on click. Download-only artifacts trigger a file download.",
},
},
},
decorators: [
(Story) => (
<div className="w-96">
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const OpenableHTML: Story = {
name: "Openable (HTML)",
args: {
artifact: makeArtifact({
title: "dashboard.html",
mimeType: "text/html",
}),
},
};
export const OpenableImage: Story = {
name: "Openable (Image)",
args: {
artifact: makeArtifact({
id: "img-card",
title: "chart.png",
mimeType: "image/png",
}),
},
};
export const OpenableCode: Story = {
name: "Openable (Code)",
args: {
artifact: makeArtifact({
title: "script.py",
mimeType: "text/x-python",
}),
},
};
export const DownloadOnly: Story = {
name: "Download Only (ZIP)",
args: {
artifact: makeArtifact({
title: "archive.zip",
mimeType: "application/zip",
sizeBytes: 2_500_000,
}),
},
};
export const PreviewableVideo: Story = {
name: "Previewable (Video)",
args: {
artifact: makeArtifact({
title: "demo.mp4",
mimeType: "video/mp4",
sizeBytes: 15_000_000,
}),
},
parameters: {
docs: {
description: {
story:
"Videos with supported formats (MP4, WebM, M4V) are previewable inline in the artifact panel.",
},
},
},
};
export const WithSize: Story = {
name: "With File Size",
args: {
artifact: makeArtifact({
title: "data.csv",
mimeType: "text/csv",
sizeBytes: 524_288,
}),
},
};
export const UserUpload: Story = {
name: "User Upload Origin",
args: {
artifact: makeArtifact({
title: "requirements.txt",
mimeType: "text/plain",
origin: "user-upload",
}),
},
};
export const ActiveState: Story = {
name: "Active (Panel Open)",
args: {
artifact: makeArtifact({ id: "active-card" }),
},
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: makeArtifact({ id: "active-card" }),
history: [],
},
});
return <Story />;
},
],
};

View File

@@ -0,0 +1,223 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactPanel } from "./ArtifactPanel";
import { useCopilotUIStore } from "../../store";
import type { ArtifactRef } from "../../store";
const PROXY_BASE = "/api/proxy/api/workspace/files";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/file-001/download`,
origin: "agent",
...overrides,
};
}
function openPanelWith(artifact: ArtifactRef) {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: artifact,
history: [],
},
});
}
const meta: Meta<typeof ArtifactPanel> = {
title: "Copilot/ArtifactPanel",
component: ArtifactPanel,
tags: ["autodocs"],
parameters: {
layout: "fullscreen",
docs: {
description: {
component:
"Side panel for previewing workspace artifacts. Supports resize, minimize, maximize, and navigation history. Bug: panel auto-opens on chat switch instead of staying collapsed.",
},
},
},
decorators: [
(Story) => (
<div className="flex h-[600px] w-full">
<div className="flex-1 bg-zinc-50 p-8">
<p className="text-sm text-zinc-500">Chat area</p>
</div>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const OpenWithTextArtifact: Story = {
name: "Open — Text File",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({ title: "notes.txt", mimeType: "text/plain" }),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/file-001/download`, () => {
return HttpResponse.text(
"These are some notes from the agent execution.\n\nKey findings:\n1. Performance improved by 23%\n2. Memory usage reduced\n3. Error rate dropped to 0.1%",
);
}),
],
},
},
};
export const OpenWithHTMLArtifact: Story = {
name: "Open — HTML",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "html-panel",
title: "dashboard.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/html-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/html-panel/download`, () => {
return HttpResponse.text(
`<!DOCTYPE html><html><body class="p-8 font-sans"><h1 class="text-2xl font-bold text-indigo-600">Dashboard</h1><p class="mt-2 text-gray-600">HTML artifact in the panel.</p></body></html>`,
);
}),
],
},
},
};
export const OpenWithImageArtifact: Story = {
name: "Open — Image (Bug: No Loading State)",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "img-panel",
title: "chart.png",
mimeType: "image/png",
sourceUrl: `${PROXY_BASE}/img-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-panel/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300"><rect width="500" height="300" fill="#dbeafe"/><text x="250" y="150" text-anchor="middle" fill="#1e40af" font-size="20">Image Preview (no skeleton)</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
docs: {
description: {
story:
"**BUG:** Image artifacts render with a bare `<img>` tag — no loading skeleton or error handling. Compare with text/HTML artifacts which show a proper skeleton while loading.",
},
},
},
};
export const MinimizedStrip: Story = {
name: "Minimized",
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: true,
isMinimized: true,
isMaximized: false,
width: 600,
activeArtifact: makeArtifact(),
history: [],
},
});
return <Story />;
},
],
};
export const ErrorState: Story = {
name: "Error — Failed to Load (Stale Artifact)",
decorators: [
(Story) => {
openPanelWith(
makeArtifact({
id: "stale-panel",
title: "old-report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/stale-panel/download`,
}),
);
return <Story />;
},
],
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/stale-panel/download`, () => {
return new HttpResponse(null, { status: 404 });
}),
],
},
docs: {
description: {
story:
"Shows what users see when opening a previously generated artifact that no longer exists on the backend (404). The 'Try again' button retries the fetch.",
},
},
},
};
export const Closed: Story = {
name: "Closed (Default State)",
decorators: [
(Story) => {
useCopilotUIStore.setState({
artifactPanel: {
isOpen: false,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: null,
history: [],
},
});
return <Story />;
},
],
parameters: {
docs: {
description: {
story:
"The default state — panel is closed. It should only open when a user clicks on an artifact card in the chat.",
},
},
},
};

View File

@@ -0,0 +1,413 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { downloadArtifact } from "../downloadArtifact";
import type { ArtifactRef } from "../../../store";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
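// These tests encode the retry policy downloadArtifact is assumed to follow:
// network errors and transient statuses (408, 429, 5xx) are retried up to
// 3 total attempts, non-transient statuses (403, 404) fail immediately, and
// the title is sanitized before being assigned to the anchor's download
// attribute. A rough sketch of the assumed loop (names hypothetical):
//   for (let attempt = 1; attempt <= 3; attempt++) {
//     const res = await fetch(artifact.sourceUrl);
//     if (res.ok) return saveBlob(await res.blob(), sanitize(artifact.title));
//     if (!TRANSIENT_STATUSES.has(res.status)) break; // fail fast
//   }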
describe("downloadArtifact", () => {
let clickSpy: ReturnType<typeof vi.fn>;
let removeSpy: ReturnType<typeof vi.fn>;
beforeEach(() => {
clickSpy = vi.fn();
removeSpy = vi.fn();
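// Stub the DOM plumbing the download path relies on: blob URLs plus a fake
// anchor element, so click()/remove() can be asserted without a real download.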
vi.stubGlobal(
"URL",
Object.assign(URL, {
createObjectURL: vi.fn().mockReturnValue("blob:fake-url"),
revokeObjectURL: vi.fn(),
}),
);
vi.spyOn(document, "createElement").mockReturnValue({
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
} as unknown as HTMLAnchorElement);
vi.spyOn(document.body, "appendChild").mockImplementation(
(node) => node as ChildNode,
);
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it("downloads file successfully on 200 response", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["pdf content"])),
}),
);
await downloadArtifact(makeArtifact());
expect(fetch).toHaveBeenCalledWith(
"/api/proxy/api/workspace/files/file-001/download",
);
expect(clickSpy).toHaveBeenCalled();
expect(removeSpy).toHaveBeenCalled();
expect(URL.revokeObjectURL).toHaveBeenCalledWith("blob:fake-url");
});
it("rejects on persistent server error after exhausting retries", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 500,
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 500",
);
expect(clickSpy).not.toHaveBeenCalled();
});
it("rejects on persistent network error after exhausting retries", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.reject(new Error("Network error"));
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Network error",
);
expect(callCount).toBe(3);
expect(clickSpy).not.toHaveBeenCalled();
});
it("retries on transient network error and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.reject(new Error("Connection reset"));
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("retries on transient 500 and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 500 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
// Should succeed on second attempt
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("sanitizes dangerous filenames", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "../../../etc/passwd" }));
expect(anchor.download).not.toContain("..");
expect(anchor.download).not.toContain("/");
});
// ── Transient retry codes ─────────────────────────────────────────
it("retries on 408 (Request Timeout) and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 408 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
it("retries on 429 (Too Many Requests) and succeeds", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({ ok: false, status: 429 });
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
});
}),
);
await downloadArtifact(makeArtifact());
expect(callCount).toBe(2);
expect(clickSpy).toHaveBeenCalled();
});
// ── Non-transient errors ──────────────────────────────────────────
it("rejects immediately on 403 (non-transient) without retry", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 403 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 403",
);
expect(callCount).toBe(1);
expect(clickSpy).not.toHaveBeenCalled();
});
it("rejects immediately on 404 without retry", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 404 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 404",
);
expect(callCount).toBe(1);
});
// ── Exhausted retries ─────────────────────────────────────────────
it("rejects after exhausting all retries on persistent 500", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({ ok: false, status: 500 });
}),
);
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
"Download failed: 500",
);
// Initial attempt + 2 retries = 3 total
expect(callCount).toBe(3);
expect(clickSpy).not.toHaveBeenCalled();
});
// ── Filename edge cases ───────────────────────────────────────────
it("falls back to 'download' when title is empty", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "" }));
expect(anchor.download).toBe("download");
});
it("falls back to 'download' when title is only dots", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
// Dot-only names should not produce a hidden or empty filename.
await downloadArtifact(makeArtifact({ title: "...." }));
expect(anchor.download).toBe("download");
});
it("replaces special chars with underscores (not empty)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: '***???"' }));
// Special chars become underscores, not removed
expect(anchor.download).toBe("_______");
});
it("strips leading dots from filename", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(makeArtifact({ title: "...hidden.txt" }));
expect(anchor.download).not.toMatch(/^\./);
expect(anchor.download).toContain("hidden.txt");
});
it("replaces Windows-reserved characters", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(
makeArtifact({ title: "file<name>with:bad*chars?.txt" }),
);
expect(anchor.download).not.toMatch(/[<>:*?]/);
});
it("replaces control characters in filename", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
const anchor = {
href: "",
download: "",
click: clickSpy,
remove: removeSpy,
};
vi.spyOn(document, "createElement").mockReturnValue(
anchor as unknown as HTMLAnchorElement,
);
await downloadArtifact(
makeArtifact({ title: "file\x00with\x1fcontrol.txt" }),
);
expect(anchor.download).not.toMatch(/[\x00-\x1f]/);
});
});

View File

@@ -0,0 +1,460 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { http, HttpResponse } from "msw";
import { ArtifactContent } from "./ArtifactContent";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import {
Code,
File,
FileHtml,
FileText,
Image,
Table,
} from "@phosphor-icons/react";
const PROXY_BASE = "/api/proxy/api/workspace/files";
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "test.txt",
mimeType: "text/plain",
sourceUrl: `${PROXY_BASE}/file-001/download`,
origin: "agent",
...overrides,
};
}
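// Mirrors the shape classifyArtifact() produces; overriding `type` steers
// ArtifactContent down a specific render path.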
function makeClassification(
overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
return {
type: "text",
icon: FileText,
label: "Text",
openable: true,
hasSourceToggle: false,
...overrides,
};
}
const meta: Meta<typeof ArtifactContent> = {
title: "Copilot/ArtifactContent",
component: ArtifactContent,
tags: ["autodocs"],
parameters: {
layout: "padded",
docs: {
description: {
component:
"Renders artifact content based on file type classification. Supports images, HTML, code, CSV, JSON, markdown, PDF, and plain text. Bug: image artifacts render as bare <img> with no loading/error states.",
},
},
},
decorators: [
(Story) => (
<div
className="flex h-[500px] w-[600px] flex-col overflow-hidden border border-zinc-200"
style={{ resize: "both" }}
>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof meta>;
export const ImageArtifactPNG: Story = {
name: "Image (PNG) — No Loading Skeleton (Bug #1)",
args: {
artifact: makeArtifact({
id: "img-png",
title: "chart.png",
mimeType: "image/png",
sourceUrl: `${PROXY_BASE}/img-png/download`,
}),
isSourceView: false,
classification: makeClassification({ type: "image", icon: Image }),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-png/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#e0e7ff"/><text x="200" y="150" text-anchor="middle" fill="#4338ca" font-size="24">PNG Placeholder</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
docs: {
description: {
story:
"**BUG:** This renders a bare `<img>` tag with no loading skeleton or error handling. Compare with WorkspaceFileRenderer which has proper Skeleton + onError states.",
},
},
},
};
export const ImageArtifactSVG: Story = {
name: "Image (SVG)",
args: {
artifact: makeArtifact({
id: "img-svg",
title: "diagram.svg",
mimeType: "image/svg+xml",
sourceUrl: `${PROXY_BASE}/img-svg/download`,
}),
isSourceView: false,
classification: makeClassification({ type: "image", icon: Image }),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/img-svg/download`, () => {
return HttpResponse.text(
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#fef3c7"/><circle cx="200" cy="150" r="80" fill="#f59e0b"/><text x="200" y="155" text-anchor="middle" fill="white" font-size="20">SVG OK</text></svg>',
{ headers: { "Content-Type": "image/svg+xml" } },
);
}),
],
},
},
};
export const HTMLArtifact: Story = {
name: "HTML",
args: {
artifact: makeArtifact({
id: "html-001",
title: "page.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/html-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/html-001/download`, () => {
return HttpResponse.text(
`<!DOCTYPE html>
<html>
<head><title>Artifact Preview</title></head>
<body class="p-8 font-sans">
<h1 class="text-2xl font-bold text-indigo-600 mb-4">HTML Artifact</h1>
<p class="text-gray-700">This is an HTML artifact rendered in a sandboxed iframe with Tailwind CSS injected.</p>
<div class="mt-4 p-4 bg-blue-50 rounded-lg border border-blue-200">
<p class="text-blue-800">Interactive content works via allow-scripts sandbox.</p>
</div>
</body>
</html>`,
{ headers: { "Content-Type": "text/html" } },
);
}),
],
},
},
};
export const CodeArtifact: Story = {
name: "Code (Python)",
args: {
artifact: makeArtifact({
id: "code-001",
title: "analysis.py",
mimeType: "text/x-python",
sourceUrl: `${PROXY_BASE}/code-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "code",
icon: Code,
label: "Code",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/code-001/download`, () => {
return HttpResponse.text(
`import pandas as pd
import matplotlib.pyplot as plt
def analyze_data(filepath: str) -> pd.DataFrame:
"""Load and analyze CSV data."""
df = pd.read_csv(filepath)
summary = df.describe()
print(f"Loaded {len(df)} rows")
return summary
if __name__ == "__main__":
result = analyze_data("data.csv")
print(result)`,
{ headers: { "Content-Type": "text/plain" } },
);
}),
],
},
},
};
export const CSVArtifact: Story = {
name: "CSV (Spreadsheet)",
args: {
artifact: makeArtifact({
id: "csv-001",
title: "data.csv",
mimeType: "text/csv",
sourceUrl: `${PROXY_BASE}/csv-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "csv",
icon: Table,
label: "Spreadsheet",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/csv-001/download`, () => {
return HttpResponse.text(
`Name,Age,City,Score
Alice,28,New York,92
Bob,35,San Francisco,87
Charlie,22,Chicago,95
Diana,31,Boston,88
Eve,27,Seattle,91`,
{ headers: { "Content-Type": "text/csv" } },
);
}),
],
},
},
};
export const JSONArtifact: Story = {
name: "JSON (Data)",
args: {
artifact: makeArtifact({
id: "json-001",
title: "config.json",
mimeType: "application/json",
sourceUrl: `${PROXY_BASE}/json-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "json",
icon: Code,
label: "Data",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/json-001/download`, () => {
return HttpResponse.text(
JSON.stringify(
{
name: "AutoGPT Agent",
version: "2.0",
capabilities: ["web_search", "code_execution", "file_io"],
settings: { maxTokens: 4096, temperature: 0.7 },
},
null,
2,
),
{ headers: { "Content-Type": "application/json" } },
);
}),
],
},
},
};
export const MarkdownArtifact: Story = {
name: "Markdown",
args: {
artifact: makeArtifact({
id: "md-001",
title: "README.md",
mimeType: "text/markdown",
sourceUrl: `${PROXY_BASE}/md-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "markdown",
icon: FileText,
label: "Document",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/md-001/download`, () => {
return HttpResponse.text(
`# Project Summary
## Overview
This is a **markdown** artifact rendered through the global renderer registry.
## Features
- Headings and paragraphs
- **Bold** and *italic* text
- Lists and code blocks
\`\`\`python
print("Hello from markdown!")
\`\`\`
> Blockquotes are also supported.`,
{ headers: { "Content-Type": "text/plain" } },
);
}),
],
},
},
};
export const PDFArtifact: Story = {
name: "PDF",
args: {
artifact: makeArtifact({
id: "pdf-001",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: `${PROXY_BASE}/pdf-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "pdf",
icon: FileText,
label: "PDF",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/pdf-001/download`, () => {
return HttpResponse.arrayBuffer(new ArrayBuffer(100), {
headers: { "Content-Type": "application/pdf" },
});
}),
],
},
docs: {
description: {
story:
"PDF artifacts are rendered in an unsandboxed iframe using a blob URL (Chromium bug #413851 prevents sandboxed PDF rendering).",
},
},
},
};
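// The panel fetches the PDF bytes itself and hands the iframe a blob: URL.
// A sketch of the assumed flow (not the exact implementation):
//   const blob = await (await fetch(artifact.sourceUrl)).blob();
//   iframeSrc = URL.createObjectURL(blob);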
export const ErrorState: Story = {
name: "Error — Failed to Load Content",
args: {
artifact: makeArtifact({
id: "error-001",
title: "old-report.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/error-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
hasSourceToggle: true,
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/error-001/download`, () => {
return new HttpResponse(null, { status: 404 });
}),
],
},
docs: {
description: {
story:
"Shows the error state when an artifact fails to load (e.g., old/expired file returning 404). Includes a 'Try again' retry button.",
},
},
},
};
export const LoadingSkeleton: Story = {
name: "Loading State",
args: {
artifact: makeArtifact({
id: "loading-001",
title: "loading.html",
mimeType: "text/html",
sourceUrl: `${PROXY_BASE}/loading-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "html",
icon: FileHtml,
label: "HTML",
}),
},
parameters: {
msw: {
handlers: [
http.get(`${PROXY_BASE}/loading-001/download`, async () => {
// Never settle, so the skeleton stays visible indefinitely
await new Promise(() => {});
return HttpResponse.text("unreachable");
}),
],
},
docs: {
description: {
story:
"Shows the skeleton loading state while content is being fetched.",
},
},
},
};
export const DownloadOnly: Story = {
name: "Download Only (Binary)",
args: {
artifact: makeArtifact({
id: "bin-001",
title: "archive.zip",
mimeType: "application/zip",
sourceUrl: `${PROXY_BASE}/bin-001/download`,
}),
isSourceView: false,
classification: makeClassification({
type: "download-only",
icon: File,
label: "File",
openable: false,
}),
},
parameters: {
docs: {
description: {
story:
"Download-only files (binary, video, etc.) are not rendered inline. The ArtifactPanel shows nothing for these — they are handled by ArtifactCard with a download button.",
},
},
},
};

View File

@@ -2,7 +2,8 @@
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { Suspense } from "react";
import { Suspense, useState } from "react";
import { Skeleton } from "@/components/ui/skeleton";
import type { ArtifactRef } from "../../../store";
import type { ArtifactClassification } from "../helpers";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
@@ -63,6 +64,90 @@ function ArtifactContentLoader({
);
}
function ArtifactImage({ src, alt }: { src: string; alt: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
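// Show a skeleton until onLoad fires; on error, swap in an alert with a
// retry button. Retrying remounts the <img>, which makes the browser
// re-request the URL.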
if (error) {
return (
<div
role="alert"
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm text-zinc-500">Failed to load image</p>
<button
type="button"
onClick={() => {
setError(false);
setLoaded(false);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Try again
</button>
</div>
);
}
return (
<div className="relative flex items-center justify-center p-4">
{!loaded && (
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
)}
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={src}
alt={alt}
className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
onLoad={() => setLoaded(true)}
onError={() => setError(true)}
/>
</div>
);
}
function ArtifactVideo({ src }: { src: string }) {
const [loaded, setLoaded] = useState(false);
const [error, setError] = useState(false);
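// Same skeleton/error pattern as ArtifactImage, but readiness is keyed off
// loadedmetadata since a video can show a first frame before the full file
// has downloaded.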
if (error) {
return (
<div
role="alert"
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
>
<p className="text-sm text-zinc-500">Failed to load video</p>
<button
type="button"
onClick={() => {
setError(false);
setLoaded(false);
}}
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
>
Try again
</button>
</div>
);
}
return (
<div className="relative flex items-center justify-center p-4">
{!loaded && (
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
)}
<video
src={src}
controls
preload="metadata"
className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
onLoadedMetadata={() => setLoaded(true)}
onError={() => setError(true)}
/>
</div>
);
}
function ArtifactRenderer({
artifact,
content,
@@ -79,17 +164,19 @@ function ArtifactRenderer({
// Image: render directly from URL (no content fetch)
if (classification.type === "image") {
return (
<div className="flex items-center justify-center p-4">
{/* eslint-disable-next-line @next/next/no-img-element */}
<img
src={artifact.sourceUrl}
alt={artifact.title}
className="max-h-full max-w-full object-contain"
/>
</div>
<ArtifactImage
key={artifact.sourceUrl}
src={artifact.sourceUrl}
alt={artifact.title}
/>
);
}
// Video: render with <video> controls (no content fetch)
if (classification.type === "video") {
return <ArtifactVideo key={artifact.sourceUrl} src={artifact.sourceUrl} />;
}
if (classification.type === "pdf" && pdfUrl) {
// No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
// (Chromium bug #413851). The blob URL has a null origin so it can't
@@ -164,7 +251,16 @@ function ArtifactRenderer({
// CSV: pass with explicit metadata so CSVRenderer matches
if (classification.type === "csv") {
const csvMeta = { mimeType: "text/csv", filename: artifact.title };
const normalizedMime = artifact.mimeType
?.toLowerCase()
.split(";")[0]
?.trim();
const csvMimeType =
normalizedMime === "text/tab-separated-values" ||
artifact.title.toLowerCase().endsWith(".tsv")
? "text/tab-separated-values"
: "text/csv";
const csvMeta = { mimeType: csvMimeType, filename: artifact.title };
const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
if (csvRenderer) {
return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;

View File

@@ -0,0 +1,67 @@
import { render, screen, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { ArtifactReactPreview } from "./ArtifactReactPreview";
import {
buildReactArtifactSrcDoc,
collectPreviewStyles,
transpileReactArtifactSource,
} from "./reactArtifactPreview";
vi.mock("./reactArtifactPreview", () => ({
buildReactArtifactSrcDoc: vi.fn(),
collectPreviewStyles: vi.fn(),
transpileReactArtifactSource: vi.fn(),
}));
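// With the transpile/build helpers mocked, these tests assert only the
// component's orchestration: source -> transpile -> srcdoc -> sandboxed iframe.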
describe("ArtifactReactPreview", () => {
beforeEach(() => {
vi.mocked(collectPreviewStyles).mockReturnValue("<style>preview</style>");
vi.mocked(buildReactArtifactSrcDoc).mockReturnValue("<html>preview</html>");
});
it("renders an iframe preview after transpilation succeeds", async () => {
vi.mocked(transpileReactArtifactSource).mockResolvedValue(
"module.exports.default = function Artifact() { return null; };",
);
const { container } = render(
<ArtifactReactPreview
source="export default function Artifact() { return null; }"
title="Artifact.tsx"
/>,
);
await waitFor(() => {
expect(buildReactArtifactSrcDoc).toHaveBeenCalledWith(
"module.exports.default = function Artifact() { return null; };",
"Artifact.tsx",
"<style>preview</style>",
);
});
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
expect(iframe?.getAttribute("title")).toBe("Artifact.tsx preview");
expect(iframe?.getAttribute("srcdoc")).toBe("<html>preview</html>");
});
it("shows a readable error when transpilation fails", async () => {
vi.mocked(transpileReactArtifactSource).mockRejectedValue(
new Error("Transpile exploded"),
);
render(
<ArtifactReactPreview
source="export default function Artifact() {"
title="Broken.tsx"
/>,
);
await waitFor(() => {
expect(screen.getByText("Failed to render React preview")).toBeTruthy();
});
expect(screen.getByText("Transpile exploded")).toBeTruthy();
});
});

View File

@@ -0,0 +1,970 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import {
cleanup,
fireEvent,
render,
screen,
waitFor,
} from "@testing-library/react";
import { ArtifactContent } from "../ArtifactContent";
import type { ArtifactRef } from "../../../../store";
import { classifyArtifact, type ArtifactClassification } from "../../helpers";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
import { ArtifactReactPreview } from "../ArtifactReactPreview";
// Mock the renderers so we don't pull in the full renderer dependency tree
vi.mock("@/components/contextual/OutputRenderers", () => ({
globalRegistry: {
getRenderer: vi.fn().mockReturnValue({
render: vi.fn((_val: unknown, meta: Record<string, unknown>) => (
<div data-testid="global-renderer">
rendered:{String(meta?.mimeType ?? "unknown")}
</div>
)),
}),
},
}));
vi.mock(
"@/components/contextual/OutputRenderers/renderers/CodeRenderer",
() => ({
codeRenderer: {
render: vi.fn((content: string) => (
<div data-testid="code-renderer">{content}</div>
)),
},
}),
);
vi.mock("../ArtifactReactPreview", () => ({
ArtifactReactPreview: vi.fn(
({ source, title }: { source: string; title: string }) => (
<div data-testid="react-preview" data-title={title}>
{source}
</div>
),
),
}));
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
return {
id: "file-001",
title: "test.txt",
mimeType: "text/plain",
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
origin: "agent",
...overrides,
};
}
function makeClassification(
overrides?: Partial<ArtifactClassification>,
): ArtifactClassification {
return {
type: "text",
icon: vi.fn(() => null) as unknown as ArtifactClassification["icon"],
label: "Text",
openable: true,
hasSourceToggle: false,
...overrides,
};
}
describe("ArtifactContent", () => {
beforeEach(() => {
vi.clearAllMocks();
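// Default fetch stub: content fetches succeed. Individual tests override
// this to exercise error and retry paths.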
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("file content here"),
blob: () => Promise.resolve(new Blob(["content"])),
}),
);
});
afterEach(() => {
cleanup();
vi.unstubAllGlobals();
});
// ── Image ─────────────────────────────────────────────────────────
it("renders image artifact as img tag with loading skeleton", () => {
const artifact = makeArtifact({
id: "img-001",
title: "photo.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-001/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
expect(img).toBeTruthy();
expect(img?.getAttribute("src")).toBe(
"/api/proxy/api/workspace/files/img-001/download",
);
expect(fetch).not.toHaveBeenCalled();
});
it("image artifact shows loading skeleton before image loads", () => {
const artifact = makeArtifact({
id: "img-skeleton",
title: "photo.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-skeleton/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
// Skeleton uses animate-pulse class
const skeleton = container.querySelector('[class*="animate-pulse"]');
expect(skeleton).toBeTruthy();
});
it("image artifact shows error state when image fails to load", () => {
const artifact = makeArtifact({
id: "img-error",
title: "broken.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-error/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
expect(img).toBeTruthy();
fireEvent.error(img!);
const errorAlert = screen.queryByRole("alert");
expect(errorAlert).toBeTruthy();
expect(screen.queryByText("Failed to load image")).toBeTruthy();
});
it("image retry resets error and re-shows img", async () => {
const artifact = makeArtifact({
id: "img-retry",
title: "retry.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/img-retry/download",
});
const classification = makeClassification({ type: "image" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const img = container.querySelector("img");
fireEvent.error(img!);
// Should show error state
await waitFor(() => {
expect(screen.queryByText("Failed to load image")).toBeTruthy();
});
// Click "Try again"
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
// Error should be cleared, img should reappear
await waitFor(() => {
expect(screen.queryByText("Failed to load image")).toBeNull();
expect(container.querySelector("img")).toBeTruthy();
});
});
// ── Video ─────────────────────────────────────────────────────────
it("renders video artifact with video tag and controls", () => {
const artifact = makeArtifact({
id: "vid-001",
title: "clip.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-001/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
expect(video).toBeTruthy();
expect(video?.hasAttribute("controls")).toBe(true);
expect(video?.getAttribute("src")).toBe(
"/api/proxy/api/workspace/files/vid-001/download",
);
expect(fetch).not.toHaveBeenCalled();
});
it("video shows loading skeleton before metadata loads", () => {
const artifact = makeArtifact({
id: "vid-skel",
title: "clip.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-skel/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const skeleton = container.querySelector('[class*="animate-pulse"]');
expect(skeleton).toBeTruthy();
// After metadata loads, skeleton should disappear
const video = container.querySelector("video");
fireEvent.loadedMetadata(video!);
expect(container.querySelector('[class*="animate-pulse"]')).toBeNull();
});
it("video shows error state when video fails to load", () => {
const artifact = makeArtifact({
id: "vid-error",
title: "broken.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-error/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
expect(video).toBeTruthy();
fireEvent.error(video!);
const errorAlert = screen.queryByRole("alert");
expect(errorAlert).toBeTruthy();
expect(screen.queryByText("Failed to load video")).toBeTruthy();
});
it("video retry resets error and re-shows video", async () => {
const artifact = makeArtifact({
id: "vid-retry",
title: "retry.mp4",
mimeType: "video/mp4",
sourceUrl: "/api/proxy/api/workspace/files/vid-retry/download",
});
const classification = makeClassification({ type: "video" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const video = container.querySelector("video");
fireEvent.error(video!);
await waitFor(() => {
expect(screen.queryByText("Failed to load video")).toBeTruthy();
});
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
await waitFor(() => {
expect(screen.queryByText("Failed to load video")).toBeNull();
expect(container.querySelector("video")).toBeTruthy();
});
});
// ── PDF ───────────────────────────────────────────────────────────
it("renders PDF artifact in unsandboxed iframe with blob URL", async () => {
const blobUrl = "blob:http://localhost/fake-pdf-blob";
vi.stubGlobal(
"URL",
Object.assign(URL, {
createObjectURL: vi.fn().mockReturnValue(blobUrl),
revokeObjectURL: vi.fn(),
}),
);
const artifact = makeArtifact({
id: "pdf-render",
title: "report.pdf",
mimeType: "application/pdf",
sourceUrl: "/api/proxy/api/workspace/files/pdf-render/download",
});
const classification = makeClassification({ type: "pdf" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("src")).toBe(blobUrl);
// No sandbox attribute — Chrome blocks PDF in sandboxed iframes
expect(iframe?.hasAttribute("sandbox")).toBe(false);
});
});
// ── Fetch error ───────────────────────────────────────────────────
it("shows error state with retry button on fetch failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
}),
);
const artifact = makeArtifact({ id: "error-content-test" });
const classification = makeClassification({ type: "html" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const errorText = await screen.findByText("Failed to load content");
expect(errorText).toBeTruthy();
const retryButtons = screen.getAllByRole("button", { name: /try again/i });
expect(retryButtons.length).toBeGreaterThan(0);
});
// ── HTML ──────────────────────────────────────────────────────────
it("renders HTML content in sandboxed iframe", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () =>
Promise.resolve("<html><body><h1>Hello World</h1></body></html>"),
}),
);
const artifact = makeArtifact({
id: "html-001",
title: "page.html",
mimeType: "text/html",
});
const classification = makeClassification({ type: "html" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTitle("page.html");
const iframe = container.querySelector("iframe");
expect(iframe).toBeTruthy();
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
});
// ── Source view ───────────────────────────────────────────────────
it("renders source view as pre tag", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("source code here"),
}),
);
const artifact = makeArtifact({ id: "source-view-test" });
const classification = makeClassification({
type: "html",
hasSourceToggle: true,
});
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={true}
classification={classification}
/>,
);
await screen.findByText("source code here");
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("source code here");
});
// ── React ─────────────────────────────────────────────────────────
it("renders react artifacts via ArtifactReactPreview", async () => {
const jsxSource = "export default function App() { return <div>Hi</div>; }";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsxSource),
}),
);
const artifact = makeArtifact({
id: "react-001",
title: "App.tsx",
mimeType: "text/tsx",
});
const classification = makeClassification({ type: "react" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const preview = await screen.findByTestId("react-preview");
expect(preview).toBeTruthy();
expect(preview.textContent).toContain(jsxSource);
expect(preview.getAttribute("data-title")).toBe("App.tsx");
});
it("routes a concrete props-based TSX artifact into ArtifactReactPreview", async () => {
const jsxSource = `
import React, { FC, useState } from "react";
interface ArtifactFile {
id: string;
name: string;
mimeType: string;
url: string;
sizeBytes: number;
}
interface Props {
files: ArtifactFile[];
onSelect: (file: ArtifactFile) => void;
}
export const previewProps: Props = {
files: [
{
id: "1",
name: "report.png",
mimeType: "image/png",
url: "/report.png",
sizeBytes: 2048,
},
],
onSelect: () => {},
};
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
const [selected, setSelected] = useState<string | null>(null);
const handleClick = (file: ArtifactFile) => {
setSelected(file.id);
onSelect(file);
};
return (
<ul>
{files.map((file) => (
<li key={file.id} onClick={() => handleClick(file)}>
<span>{selected === file.id ? "selected" : file.name}</span>
</li>
))}
</ul>
);
};
export default ArtifactList;
`;
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsxSource),
}),
);
const artifact = makeArtifact({
id: "react-props-001",
title: "ArtifactList.tsx",
mimeType: "text/tsx",
});
const classification = classifyArtifact(artifact.mimeType, artifact.title);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const preview = await screen.findByTestId("react-preview");
expect(preview.textContent).toContain("previewProps");
expect(preview.getAttribute("data-title")).toBe("ArtifactList.tsx");
expect(vi.mocked(ArtifactReactPreview).mock.calls[0]?.[0]).toEqual(
expect.objectContaining({
source: expect.stringContaining("export const previewProps"),
title: "ArtifactList.tsx",
}),
);
});
// ── Code ──────────────────────────────────────────────────────────
it("renders code artifacts via codeRenderer", async () => {
const code = 'def hello():\n print("hi")';
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(code),
}),
);
const artifact = makeArtifact({
id: "code-render-001",
title: "script.py",
mimeType: "text/x-python",
});
const classification = makeClassification({ type: "code" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("code-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain(code);
});
it.each([
{
filename: "events.jsonl",
mimeType: "application/x-ndjson",
content: '{"event":"start"}\n{"event":"finish"}',
},
{
filename: ".env.local",
mimeType: "text/plain",
content: "OPENAI_API_KEY=test\nDEBUG=true",
},
{
filename: "Dockerfile",
mimeType: "text/plain",
content: "FROM node:20\nRUN pnpm install",
},
{
filename: "schema.graphql",
mimeType: "text/plain",
content: "type Query { viewer: User }",
},
])(
"renders concrete code artifact $filename through codeRenderer",
async ({ filename, mimeType, content }) => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(content),
}),
);
const artifact = makeArtifact({
id: `code-${filename}`,
title: filename,
mimeType,
});
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTestId("code-renderer");
expect(classification.type).toBe("code");
expect(vi.mocked(codeRenderer.render)).toHaveBeenCalledWith(
content,
expect.objectContaining({
filename,
mimeType,
type: "code",
}),
);
},
);
// ── JSON ──────────────────────────────────────────────────────────
it("renders valid JSON via globalRegistry", async () => {
const jsonContent = JSON.stringify({ key: "value" }, null, 2);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(jsonContent),
}),
);
const artifact = makeArtifact({
id: "json-render-001",
title: "data.json",
mimeType: "application/json",
});
const classification = makeClassification({ type: "json" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("application/json");
});
it("renders invalid JSON as fallback pre tag", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"
);
const originalImpl = vi
.mocked(globalRegistry.getRenderer)
.getMockImplementation();
// For invalid JSON, JSON.parse throws, then the registry fallback
// also returns null → falls through to <pre>
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("{invalid json!!!"),
}),
);
const artifact = makeArtifact({
id: "json-invalid-001",
title: "bad.json",
mimeType: "application/json",
});
const classification = makeClassification({ type: "json" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("{invalid json!!!");
});
// Restore
if (originalImpl) {
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
}
});
// ── CSV ───────────────────────────────────────────────────────────
it("renders CSV via globalRegistry with text/csv metadata", async () => {
const csvContent = "Name,Age\nAlice,30\nBob,25";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(csvContent),
}),
);
const artifact = makeArtifact({
id: "csv-render-001",
title: "data.csv",
mimeType: "text/csv",
});
const classification = makeClassification({
type: "csv",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/csv");
});
it("renders TSV via globalRegistry with tab-separated metadata", async () => {
const tsvContent = "Name\tAge\nAlice\t30\nBob\t25";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(tsvContent),
}),
);
const artifact = makeArtifact({
id: "tsv-render-001",
title: "data.tsv",
mimeType: "text/tab-separated-values",
});
const classification = makeClassification({
type: "csv",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/tab-separated-values");
});
// ── Markdown ──────────────────────────────────────────────────────
it("renders markdown via globalRegistry", async () => {
const mdContent = "# Hello\n\nThis is **markdown**.";
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(mdContent),
}),
);
const artifact = makeArtifact({
id: "md-render-001",
title: "README.md",
mimeType: "text/markdown",
});
const classification = makeClassification({
type: "markdown",
hasSourceToggle: true,
});
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
expect(rendered.textContent).toContain("text/markdown");
});
// ── Text fallback ─────────────────────────────────────────────────
it("renders text artifacts via globalRegistry fallback", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("plain text content"),
}),
);
const artifact = makeArtifact({
id: "text-render-001",
title: "notes.txt",
mimeType: "text/plain",
});
const classification = makeClassification({ type: "text" });
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
const rendered = await screen.findByTestId("global-renderer");
expect(rendered).toBeTruthy();
});
it.each([
{
filename: "calendar.ics",
mimeType: "text/calendar",
content: "BEGIN:VCALENDAR\nVERSION:2.0\nEND:VCALENDAR",
},
{
filename: "contact.vcf",
mimeType: "text/vcard",
content: "BEGIN:VCARD\nVERSION:4.0\nFN:Alice Example\nEND:VCARD",
},
])(
"renders concrete text artifact $filename through the global renderer path",
async ({ filename, mimeType, content }) => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve(content),
}),
);
const artifact = makeArtifact({
id: `text-${filename}`,
title: filename,
mimeType,
});
const classification = classifyArtifact(
artifact.mimeType,
artifact.title,
);
render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await screen.findByTestId("global-renderer");
expect(classification.type).toBe("text");
expect(vi.mocked(globalRegistry.getRenderer)).toHaveBeenCalledWith(
content,
expect.objectContaining({
filename,
mimeType,
}),
);
},
);
it("falls back to pre tag when no renderer matches", async () => {
const { globalRegistry } = await import(
"@/components/contextual/OutputRenderers"
);
const originalImpl = vi
.mocked(globalRegistry.getRenderer)
.getMockImplementation();
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: true,
text: () => Promise.resolve("raw content fallback"),
}),
);
const artifact = makeArtifact({
id: "fallback-pre-001",
title: "unknown.txt",
mimeType: "text/plain",
});
const classification = makeClassification({ type: "text" });
const { container } = render(
<ArtifactContent
artifact={artifact}
isSourceView={false}
classification={classification}
/>,
);
await waitFor(() => {
const pre = container.querySelector("pre");
expect(pre).toBeTruthy();
expect(pre?.textContent).toBe("raw content fallback");
});
// Restore
if (originalImpl) {
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
}
});
});

View File

@@ -3,6 +3,7 @@ import { renderHook, waitFor, act } from "@testing-library/react";
import {
useArtifactContent,
getCachedArtifactContent,
clearContentCache,
} from "../useArtifactContent";
import type { ArtifactRef } from "../../../../store";
import type { ArtifactClassification } from "../../helpers";
@@ -33,6 +34,7 @@ function makeClassification(
describe("useArtifactContent", () => {
beforeEach(() => {
clearContentCache();
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
@@ -44,6 +46,7 @@ describe("useArtifactContent", () => {
});
afterEach(() => {
clearContentCache();
vi.restoreAllMocks();
});
@@ -109,9 +112,12 @@ describe("useArtifactContent", () => {
useArtifactContent(artifact, classification),
);
await waitFor(() => {
expect(result.current.error).toBeTruthy();
});
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("404");
expect(result.current.content).toBeNull();
@@ -132,6 +138,176 @@ describe("useArtifactContent", () => {
expect(getCachedArtifactContent("cache-test")).toBe("file content here");
});
it("sets error on fetch failure for HTML artifacts (stale artifact)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
}),
);
const artifact = makeArtifact({ id: "stale-html-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("404");
expect(result.current.content).toBeNull();
});
it("sets error on network failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockRejectedValue(new Error("Network error")),
);
const artifact = makeArtifact({ id: "network-error-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("Network error");
expect(result.current.content).toBeNull();
});
it("retries transient HTML fetch failures before surfacing an error", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount < 3) {
return Promise.resolve({
ok: false,
status: 503,
headers: {
get: () => "application/json",
},
json: () => Promise.resolve({ detail: "temporary upstream error" }),
});
}
return Promise.resolve({
ok: true,
text: () => Promise.resolve("<html>ok now</html>"),
});
}),
);
const artifact = makeArtifact({ id: "transient-html-retry" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.content).toBe("<html>ok now</html>");
},
{ timeout: 2500 },
);
expect(callCount).toBe(3);
expect(result.current.error).toBeNull();
});
it("surfaces backend error detail from JSON responses", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 404,
headers: {
get: () => "application/json",
},
json: () => Promise.resolve({ detail: "File not found" }),
}),
);
const artifact = makeArtifact({ id: "json-error-detail" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("404");
expect(result.current.error).toContain("File not found");
});
it("retry after 404 on HTML artifact clears cache and re-fetches", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
ok: false,
status: 404,
text: () => Promise.resolve("Not found"),
});
}
return Promise.resolve({
ok: true,
text: () => Promise.resolve("<html>recovered</html>"),
});
}),
);
const artifact = makeArtifact({ id: "retry-html-artifact" });
const classification = makeClassification({ type: "html" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(() => {
expect(result.current.error).toBeTruthy();
});
act(() => {
result.current.retry();
});
await waitFor(
() => {
expect(result.current.content).toBe("<html>recovered</html>");
},
{ timeout: 2500 },
);
expect(result.current.error).toBeNull();
});
it("retry clears cache and re-fetches", async () => {
let callCount = 0;
vi.stubGlobal(
@@ -164,4 +340,162 @@ describe("useArtifactContent", () => {
expect(result.current.content).toBe("response 2");
});
});
// ── Non-transient errors ──────────────────────────────────────────
it("rejects immediately on 403 without retrying", async () => {
let callCount = 0;
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation(() => {
callCount++;
return Promise.resolve({
ok: false,
status: 403,
text: () => Promise.resolve("Forbidden"),
});
}),
);
const artifact = makeArtifact({ id: "forbidden-no-retry" });
const classification = makeClassification({ type: "text" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(callCount).toBe(1);
expect(result.current.error).toContain("403");
});
// ── Video skip-fetch ──────────────────────────────────────────────
it("skips fetch for video artifacts (like image)", async () => {
const artifact = makeArtifact({
id: "video-skip",
mimeType: "video/mp4",
});
const classification = makeClassification({ type: "video" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
expect(result.current.isLoading).toBe(false);
expect(result.current.content).toBeNull();
expect(result.current.pdfUrl).toBeNull();
expect(fetch).not.toHaveBeenCalled();
});
// ── PDF error paths ───────────────────────────────────────────────
it("sets error on PDF fetch failure (non-2xx)", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockResolvedValue({
ok: false,
status: 500,
text: () => Promise.resolve("Server Error"),
}),
);
const artifact = makeArtifact({ id: "pdf-error" });
const classification = makeClassification({ type: "pdf" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("500");
expect(result.current.pdfUrl).toBeNull();
});
it("sets error on PDF network failure", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockRejectedValue(new Error("PDF network failure")),
);
const artifact = makeArtifact({ id: "pdf-network-error" });
const classification = makeClassification({ type: "pdf" });
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(
() => {
expect(result.current.error).toBeTruthy();
},
{ timeout: 2500 },
);
expect(result.current.error).toContain("PDF network failure");
expect(result.current.pdfUrl).toBeNull();
});
// ── LRU cache eviction ────────────────────────────────────────────
it("evicts oldest entry when cache exceeds 12 items", async () => {
vi.stubGlobal(
"fetch",
vi.fn().mockImplementation((url: string) => {
const fileId = url.match(/files\/([^/]+)\/download/)?.[1] ?? "unknown";
return Promise.resolve({
ok: true,
text: () => Promise.resolve(`content-${fileId}`),
});
}),
);
const classification = makeClassification({ type: "text" });
// Fill the cache with 12 entries (cache max = 12)
for (let i = 0; i < 12; i++) {
const artifact = makeArtifact({
id: `lru-${i}`,
sourceUrl: `/api/proxy/api/workspace/files/lru-${i}/download`,
});
const { result } = renderHook(() =>
useArtifactContent(artifact, classification),
);
await waitFor(() => {
expect(result.current.isLoading).toBe(false);
});
}
// All 12 should be cached
expect(getCachedArtifactContent("lru-0")).toBe("content-lru-0");
expect(getCachedArtifactContent("lru-11")).toBe("content-lru-11");
// Adding a 13th should evict lru-0 (the oldest)
const artifact13 = makeArtifact({
id: "lru-12",
sourceUrl: "/api/proxy/api/workspace/files/lru-12/download",
});
const { result: result13 } = renderHook(() =>
useArtifactContent(artifact13, classification),
);
await waitFor(() => {
expect(result13.current.isLoading).toBe(false);
});
expect(getCachedArtifactContent("lru-0")).toBeUndefined();
expect(getCachedArtifactContent("lru-1")).toBe("content-lru-1");
expect(getCachedArtifactContent("lru-12")).toBe("content-lru-12");
});
});
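
The eviction assertions above lean on Map insertion order as the recency proxy. A minimal sketch of the fill-and-evict behavior under test, assuming the hook's CONTENT_CACHE_MAX of 12 (names here are illustrative, not the hook's internals):

// A Map iterates keys in insertion order, so the first key is the oldest.
const CACHE_MAX = 12;
const cache = new Map<string, string>();

function cacheContent(id: string, text: string): void {
  if (cache.size >= CACHE_MAX) {
    const oldest = cache.keys().next().value; // first-inserted key
    if (oldest !== undefined) cache.delete(oldest); // drop before adding the 13th
  }
  cache.set(id, text);
}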

View File

@@ -85,4 +85,35 @@ describe("buildReactArtifactSrcDoc", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("box-sizing: border-box");
});
it("supports a named previewProps export in the runtime", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("moduleExports.previewProps");
expect(doc).toContain("React.createElement(Component, previewProps || {})");
});
it("includes a helpful message for components that expect props", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("This component appears to expect props.");
expect(doc).toContain("previewProps");
});
it("checks componentExpectsProps on the raw component before wrapping", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("RawComponent.length > 0");
expect(doc).toContain("wrapWithProviders(RawComponent");
});
it("wrapWithProviders forwards props to the wrapped component", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain("function WrappedArtifactPreview(props)");
expect(doc).toContain("React.createElement(Component, props)");
});
it("supports named exported components and provider wrappers in the runtime", () => {
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
expect(doc).toContain('name.endsWith("Provider")');
expect(doc).toContain("/^[A-Z]/.test(name)");
expect(doc).toContain("wrapWithProviders");
});
});

View File

@@ -169,8 +169,8 @@ export function buildReactArtifactSrcDoc(
return Component;
}
return function WrappedArtifactPreview() {
let tree = React.createElement(Component);
return function WrappedArtifactPreview(props) {
let tree = React.createElement(Component, props);
for (let i = providers.length - 1; i >= 0; i -= 1) {
tree = React.createElement(providers[i], null, tree);
@@ -180,6 +180,17 @@ export function buildReactArtifactSrcDoc(
};
}
function getPreviewProps(moduleExports) {
if (
moduleExports.previewProps &&
typeof moduleExports.previewProps === "object"
) {
return moduleExports.previewProps;
}
return null;
}
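// Illustrative shapes (assumed examples): getPreviewProps returns the
// export only when it is a non-null object.
//   exports.previewProps = { files: [] }  → { files: [] }
//   exports.previewProps = "sample"       → null
//   (no previewProps export)              → null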
function require(name) {
if (name === "react") {
return React;
@@ -235,6 +246,11 @@ export function buildReactArtifactSrcDoc(
render() {
if (this.state.error) {
const propsHelp =
this.props.componentExpectsProps && !this.props.hasPreviewProps
? "\\n\\nThis component appears to expect props. Export a named previewProps object with sample values to render it in artifact preview."
: "";
return React.createElement(
"div",
{
@@ -249,7 +265,9 @@ export function buildReactArtifactSrcDoc(
whiteSpace: "pre-wrap",
},
},
this.state.error.stack || this.state.error.message || String(this.state.error),
(this.state.error.stack ||
this.state.error.message ||
String(this.state.error)) + propsHelp,
);
}
@@ -296,16 +314,19 @@ export function buildReactArtifactSrcDoc(
moduleExports.App = executionResult.app;
}
const Component = wrapWithProviders(
getRenderableCandidate(moduleExports),
moduleExports,
);
const RawComponent = getRenderableCandidate(moduleExports);
const componentExpectsProps = RawComponent.length > 0;
const Component = wrapWithProviders(RawComponent, moduleExports);
const previewProps = getPreviewProps(moduleExports);
ReactDOM.createRoot(rootElement).render(
React.createElement(
PreviewErrorBoundary,
null,
React.createElement(Component),
{
componentExpectsProps: componentExpectsProps,
hasPreviewProps: previewProps != null,
},
React.createElement(Component, previewProps || {}),
),
);
} catch (error) {

View File

@@ -48,4 +48,104 @@ describe("transpileReactArtifactSource", () => {
expect(out).not.toContain(": string");
expect(out).toContain("function greet(name)");
});
it("transpiles a concrete props-based artifact with previewProps", async () => {
const src = `
import React, { FC, useState } from "react";
interface ArtifactFile {
id: string;
name: string;
mimeType: string;
url: string;
sizeBytes: number;
}
interface Props {
files: ArtifactFile[];
onSelect: (file: ArtifactFile) => void;
}
export const previewProps: Props = {
files: [
{
id: "1",
name: "report.png",
mimeType: "image/png",
url: "/report.png",
sizeBytes: 2048,
},
],
onSelect: () => {},
};
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
const [selected, setSelected] = useState<string | null>(null);
const handleClick = (file: ArtifactFile) => {
setSelected(file.id);
onSelect(file);
};
return (
<ul>
{files.map((file) => (
<li key={file.id} onClick={() => handleClick(file)}>
<span>{selected === file.id ? "selected" : file.name}</span>
</li>
))}
</ul>
);
};
export default ArtifactList;
`;
const out = await transpileReactArtifactSource(src, "ArtifactList.tsx");
expect(out).toContain("exports.previewProps");
expect(out).toContain("exports.default = ArtifactList");
expect(out).toContain("useState");
expect(out).not.toContain("interface Props");
expect(out).not.toContain("interface ArtifactFile");
});
it("transpiles a named export artifact without a default export", async () => {
const src = `
export function ResultsGrid() {
return (
<section>
<h1>Results</h1>
<p>Named export preview</p>
</section>
);
}
`;
const out = await transpileReactArtifactSource(src, "ResultsGrid.tsx");
expect(out).toContain("exports.ResultsGrid = ResultsGrid");
expect(out).toMatch(/\.createElement\(/);
expect(out).not.toContain("<section>");
});
it("transpiles a provider-wrapped artifact with separate provider and component exports", async () => {
const src = `
import React from "react";
export function DemoProvider({ children }: { children: React.ReactNode }) {
return <div data-theme="demo">{children}</div>;
}
export function DashboardCard() {
return <main>Provider-backed preview</main>;
}
`;
const out = await transpileReactArtifactSource(src, "DashboardCard.tsx");
expect(out).toContain("exports.DemoProvider = DemoProvider");
expect(out).toContain("exports.DashboardCard = DashboardCard");
expect(out).not.toContain("React.ReactNode");
});
});

View File

@@ -7,12 +7,116 @@ import type { ArtifactClassification } from "../helpers";
// Cap on cached text artifacts. Long sessions with many large artifacts
// would otherwise hold every opened one in memory.
const CONTENT_CACHE_MAX = 12;
const CONTENT_FETCH_MAX_RETRIES = 2;
const CONTENT_FETCH_RETRY_DELAY_MS = 500;
// Module-level LRU keyed by artifact id so a sibling action (e.g. Copy
// in ArtifactPanelHeader) can read what the panel already fetched without
// re-hitting the network.
const contentCache = new Map<string, string>();
class ArtifactFetchError extends Error {}
function isTransientArtifactFetchStatus(status: number): boolean {
return status === 408 || status === 429 || status >= 500;
}
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function getArtifactErrorMessage(body: unknown): string | null {
if (typeof body === "string") {
const trimmed = body.replace(/\s+/g, " ").trim();
return trimmed || null;
}
if (!body || typeof body !== "object") return null;
if (
"detail" in body &&
typeof body.detail === "string" &&
body.detail.trim().length > 0
) {
return body.detail.trim();
}
if (
"error" in body &&
typeof body.error === "string" &&
body.error.trim().length > 0
) {
return body.error.trim();
}
if (
"detail" in body &&
body.detail &&
typeof body.detail === "object" &&
"message" in body.detail &&
typeof body.detail.message === "string" &&
body.detail.message.trim().length > 0
) {
return body.detail.message.trim();
}
return null;
}
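// Illustrative payloads (assumed examples, one per branch above):
//   { detail: "File not found" }          → "File not found"
//   { error: "Forbidden" }                → "Forbidden"
//   { detail: { message: "Quota hit" } }  → "Quota hit"
//   "  Internal \n error "                → "Internal error"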
async function parseArtifactFetchError(response: Response): Promise<string> {
const prefix = `Failed to fetch: ${response.status}`;
const contentType =
response.headers?.get?.("content-type")?.toLowerCase() ?? "";
try {
if (
contentType.includes("application/json") &&
typeof response.json === "function"
) {
const body = await response.json();
const detail = getArtifactErrorMessage(body);
return detail ? `${prefix} ${detail}` : prefix;
}
if (typeof response.text === "function") {
const text = await response.text();
const detail = getArtifactErrorMessage(text);
return detail ? `${prefix} ${detail}` : prefix;
}
} catch {
return prefix;
}
return prefix;
}
async function fetchArtifactResponse(url: string): Promise<Response> {
for (let attempt = 0; attempt <= CONTENT_FETCH_MAX_RETRIES; attempt++) {
try {
const response = await fetch(url);
if (response.ok) return response;
if (
!isTransientArtifactFetchStatus(response.status) ||
attempt === CONTENT_FETCH_MAX_RETRIES
) {
throw new ArtifactFetchError(await parseArtifactFetchError(response));
}
} catch (error) {
if (error instanceof ArtifactFetchError) throw error;
if (attempt === CONTENT_FETCH_MAX_RETRIES) {
throw error instanceof Error
? error
: new Error("Failed to fetch artifact");
}
}
await sleep(CONTENT_FETCH_RETRY_DELAY_MS);
}
throw new Error("Failed to fetch artifact");
}
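// Illustrative retry timeline (using the constants above: 2 retries, 500 ms):
//   attempt 0: HTTP 503 → sleep 500 ms
//   attempt 1: network error → sleep 500 ms
//   attempt 2: HTTP 200 → response returned
// A 403 is non-transient, so attempt 0 throws ArtifactFetchError immediately.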
export function getCachedArtifactContent(id: string): string | undefined {
return contentCache.get(id);
}
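// Hypothetical sibling call site (e.g. a Copy action) reading the panel's
// cache without another network round-trip:
//   const cached = getCachedArtifactContent(artifact.id);
//   if (cached != null) await navigator.clipboard.writeText(cached);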
@@ -64,7 +168,7 @@ export function useArtifactContent(
}, [artifact.id, isLoading]);
useEffect(() => {
if (classification.type === "image") {
if (classification.type === "image" || classification.type === "video") {
setContent(null);
setPdfUrl(null);
setError(null);
@@ -80,11 +184,8 @@ export function useArtifactContent(
let objectUrl: string | null = null;
setContent(null);
setPdfUrl(null);
fetch(artifact.sourceUrl)
.then((res) => {
if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
return res.blob();
})
fetchArtifactResponse(artifact.sourceUrl)
.then((res) => res.blob())
.then((blob) => {
objectUrl = URL.createObjectURL(blob);
if (cancelled) {
@@ -121,11 +222,8 @@ export function useArtifactContent(
cancelled = true;
};
}
fetch(artifact.sourceUrl)
.then((res) => {
if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
return res.text();
})
fetchArtifactResponse(artifact.sourceUrl)
.then((res) => res.text())
.then((text) => {
if (!cancelled) {
if (cache.size >= CONTENT_CACHE_MAX) {

View File

@@ -1,5 +1,31 @@
import type { ArtifactRef } from "../../store";
const MAX_RETRIES = 2;
const RETRY_DELAY_MS = 500;
function isTransientError(status: number): boolean {
return status >= 500 || status === 408 || status === 429;
}
class DownloadError extends Error {}
async function fetchWithRetry(url: string, retries: number): Promise<Response> {
for (let attempt = 0; attempt <= retries; attempt++) {
try {
const res = await fetch(url);
if (res.ok) return res;
if (!isTransientError(res.status) || attempt === retries) {
throw new DownloadError(`Download failed: ${res.status}`);
}
} catch (error) {
if (error instanceof DownloadError) throw error;
if (attempt === retries) throw error;
}
await new Promise((r) => setTimeout(r, RETRY_DELAY_MS));
}
throw new Error("Unreachable");
}
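// Same transient-status policy as the artifact-content hook (5xx, 408, 429).
// Illustrative: 503 → wait 500 ms → 503 → wait 500 ms → 200 resolves;
// a 404 throws DownloadError on the first attempt with no retry.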
/**
* Trigger a file download from an artifact URL.
*
@@ -7,26 +33,28 @@ import type { ArtifactRef } from "../../store";
* ignores the `download` attribute on cross-origin responses (GCS signed
* URLs), and some browsers require the anchor to be attached to the DOM
* before `.click()` fires the download.
*
* Retries up to {@link MAX_RETRIES} times on transient server errors (5xx,
* 408, 429) to handle intermittent proxy/GCS failures.
*/
export function downloadArtifact(artifact: ArtifactRef): Promise<void> {
// Replace path separators, Windows-reserved chars, control chars, and
// parent-dir sequences so the browser-assigned filename is safe to write
// anywhere on the user's filesystem.
const safeName =
artifact.title
.replace(/\.\./g, "_")
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
.replace(/^\.+/, "") || "download";
return fetch(artifact.sourceUrl)
.then((res) => {
if (!res.ok) throw new Error(`Download failed: ${res.status}`);
return res.blob();
})
const collapsedDots = artifact.title.replace(/\.\./g, "");
const hasVisibleName = collapsedDots.replace(/^\.+/, "").length > 0;
const safeName = artifact.title
.replace(/\.\./g, "_")
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
.replace(/^\.+/, "");
return fetchWithRetry(artifact.sourceUrl, MAX_RETRIES)
.then((res) => res.blob())
.then((blob) => {
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = safeName;
a.download = safeName && hasVisibleName ? safeName : "download";
document.body.appendChild(a);
a.click();
a.remove();

View File

@@ -56,7 +56,7 @@ describe("classifyArtifact", () => {
expect(classifyArtifact("application/octet-stream", "x").openable).toBe(
false,
);
expect(classifyArtifact("video/mp4", "clip.mp4").openable).toBe(false);
expect(classifyArtifact("audio/mpeg", "track.mp3").openable).toBe(false);
});
it("defaults unknown extension+MIME to download-only (not text)", () => {
@@ -76,4 +76,398 @@ describe("classifyArtifact", () => {
const c = classifyArtifact("text/plain", "data.csv");
expect(c.type).toBe("csv");
});
it("classifies video/mp4 as video (previewable)", () => {
const c = classifyArtifact("video/mp4", "clip.mp4");
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
});
it("classifies video/webm as video (previewable)", () => {
const c = classifyArtifact("video/webm", "clip.webm");
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
});
// ── Extension coverage ────────────────────────────────────────────
it("routes .htm as html (not just .html)", () => {
const c = classifyArtifact(null, "page.htm");
expect(c.type).toBe("html");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .json as json with source toggle", () => {
const c = classifyArtifact(null, "config.json");
expect(c.type).toBe("json");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .txt as text", () => {
expect(classifyArtifact(null, "notes.txt").type).toBe("text");
});
it("routes .log as text", () => {
expect(classifyArtifact(null, "server.log").type).toBe("text");
});
it("routes .mdx as markdown", () => {
expect(classifyArtifact(null, "docs.mdx").type).toBe("markdown");
});
it("routes browser-safe video extensions to video", () => {
for (const ext of [".mp4", ".webm", ".m4v"]) {
const c = classifyArtifact(null, `clip${ext}`);
expect(c.type).toBe("video");
expect(c.openable).toBe(true);
}
});
it("keeps legacy or unsupported video extensions download-only", () => {
for (const ext of [".ogg", ".mov", ".avi", ".mkv", ".flv", ".mpeg"]) {
const c = classifyArtifact(null, `clip${ext}`);
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
it("routes all code extensions to code", () => {
const codeExts = [
"main.js",
"app.ts",
"theme.scss",
"legacy.less",
"schema.graphql",
"query.gql",
"api.proto",
"main.dart",
"lib.rb",
"server.rs",
"App.java",
"main.c",
"util.cpp",
"header.h",
"Program.cs",
"index.php",
"main.swift",
"App.kt",
"run.sh",
"start.bash",
"prompt.zsh",
"config.toml",
"settings.ini",
"app.cfg",
"query.sql",
"analysis.r",
"game.lua",
"script.pl",
"Calc.scala",
];
for (const file of codeExts) {
expect(classifyArtifact(null, file).type).toBe("code");
}
});
it("routes config filenames and extensions to code", () => {
const configFiles = [
".env",
".env.local",
"app.properties",
"service.conf",
".gitignore",
"Dockerfile",
"Makefile",
];
for (const file of configFiles) {
expect(classifyArtifact(null, file).type).toBe("code");
}
});
it("routes .jsonl as code for now", () => {
const c = classifyArtifact(null, "events.jsonl");
expect(c.type).toBe("code");
});
it("routes .tsv as csv/spreadsheet", () => {
const c = classifyArtifact(null, "table.tsv");
expect(c.type).toBe("csv");
expect(c.hasSourceToggle).toBe(true);
});
it("routes .ics and .vcf as text", () => {
expect(classifyArtifact(null, "calendar.ics").type).toBe("text");
expect(classifyArtifact(null, "contact.vcf").type).toBe("text");
});
it("routes all image extensions to image", () => {
for (const ext of [".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico"]) {
expect(classifyArtifact(null, `file${ext}`).type).toBe("image");
}
});
// ── MIME fallback coverage ────────────────────────────────────────
it("routes application/json MIME to json", () => {
const c = classifyArtifact("application/json", "noext");
expect(c.type).toBe("json");
});
it("routes text/x-* MIME prefix to code", () => {
expect(classifyArtifact("text/x-python", "noext").type).toBe("code");
expect(classifyArtifact("text/x-c", "noext").type).toBe("code");
expect(classifyArtifact("text/x-java-source", "noext").type).toBe("code");
});
it("routes react MIME types to react", () => {
expect(classifyArtifact("text/jsx", "noext").type).toBe("react");
expect(classifyArtifact("text/tsx", "noext").type).toBe("react");
expect(classifyArtifact("application/jsx", "noext").type).toBe("react");
expect(classifyArtifact("application/x-typescript-jsx", "noext").type).toBe(
"react",
);
});
it("routes JavaScript/TypeScript MIME to code", () => {
expect(classifyArtifact("application/javascript", "noext").type).toBe(
"code",
);
expect(classifyArtifact("text/javascript", "noext").type).toBe("code");
expect(classifyArtifact("application/typescript", "noext").type).toBe(
"code",
);
expect(classifyArtifact("text/typescript", "noext").type).toBe("code");
});
it("routes XML MIME to code", () => {
expect(classifyArtifact("application/xml", "noext").type).toBe("code");
expect(classifyArtifact("text/xml", "noext").type).toBe("code");
});
it("routes text/x-markdown MIME to markdown", () => {
expect(classifyArtifact("text/x-markdown", "noext").type).toBe("markdown");
});
it("routes text/csv MIME to csv", () => {
expect(classifyArtifact("text/csv", "noext").type).toBe("csv");
});
it("routes TSV MIME to csv", () => {
expect(classifyArtifact("text/tab-separated-values", "noext").type).toBe(
"csv",
);
});
it("routes unknown text/* MIME to text (not download-only)", () => {
expect(classifyArtifact("text/rtf", "noext").type).toBe("text");
});
it("routes browser-safe image MIME types to image", () => {
expect(classifyArtifact("image/avif", "noext").type).toBe("image");
});
it("keeps unsupported image MIME types download-only", () => {
for (const mime of [
"image/tiff",
"image/x-portable-pixmap",
"image/x-portable-graymap",
]) {
const c = classifyArtifact(mime, "noext");
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
it("routes browser-safe video MIME types to video", () => {
expect(classifyArtifact("video/mp4", "noext").type).toBe("video");
expect(classifyArtifact("video/webm", "noext").type).toBe("video");
});
it("keeps legacy or unsupported video MIME types download-only", () => {
for (const mime of [
"video/x-msvideo",
"video/x-flv",
"video/mpeg",
"video/quicktime",
"video/x-matroska",
"video/ogg",
]) {
const c = classifyArtifact(mime, "noext");
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
}
});
// ── BINARY_MIMES coverage ────────────────────────────────────────
it("treats all BINARY_MIMES entries as download-only", () => {
const binaryMimes = [
"application/zip",
"application/x-zip-compressed",
"application/gzip",
"application/x-tar",
"application/x-rar-compressed",
"application/x-7z-compressed",
"application/octet-stream",
"application/x-executable",
"application/x-msdos-program",
"application/vnd.microsoft.portable-executable",
];
for (const mime of binaryMimes) {
const c = classifyArtifact(mime, "noext");
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
}
});
it("treats audio/* MIME as download-only", () => {
expect(classifyArtifact("audio/mpeg", "noext").openable).toBe(false);
expect(classifyArtifact("audio/wav", "noext").openable).toBe(false);
expect(classifyArtifact("audio/ogg", "noext").openable).toBe(false);
});
// ── Size gate edge cases ──────────────────────────────────────────
it("does NOT gate files at exactly 10MB (boundary is >10MB)", () => {
const tenMB = 10 * 1024 * 1024;
const c = classifyArtifact("text/plain", "exact.txt", tenMB);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
it("gates files at 10MB + 1 byte", () => {
const overTenMB = 10 * 1024 * 1024 + 1;
const c = classifyArtifact("text/plain", "big.txt", overTenMB);
expect(c.type).toBe("download-only");
expect(c.openable).toBe(false);
});
it("does not gate when sizeBytes is 0", () => {
const c = classifyArtifact("text/plain", "empty.txt", 0);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
it("does not gate when sizeBytes is undefined", () => {
const c = classifyArtifact("text/plain", "file.txt", undefined);
expect(c.type).toBe("text");
expect(c.openable).toBe(true);
});
// ── Extension over MIME priority ──────────────────────────────────
it("extension wins over MIME for JSON (MIME says text, ext says json)", () => {
const c = classifyArtifact("text/plain", "data.json");
expect(c.type).toBe("json");
});
it("extension wins over MIME for markdown", () => {
const c = classifyArtifact("text/plain", "README.md");
expect(c.type).toBe("markdown");
});
// ── Null/missing inputs ───────────────────────────────────────────
it("handles null MIME with no filename as download-only", () => {
const c = classifyArtifact(null, undefined);
expect(c.type).toBe("download-only");
});
it("handles null MIME with empty filename as download-only", () => {
const c = classifyArtifact(null, "");
expect(c.type).toBe("download-only");
});
it("handles known config files with no extension", () => {
const c = classifyArtifact(null, "Makefile");
expect(c.type).toBe("code");
});
// ── Exotic/compound extensions must NOT open the side panel ───────
// These are real file types agents might produce. Every single one
// must be download-only so we never try to render binary garbage.
it("does not open .tar.gz (compound extension takes last segment)", () => {
// getExtension("archive.tar.gz") → ".gz" which is not in EXT_KIND
const c = classifyArtifact(null, "archive.tar.gz");
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
});
it("does not open .tar.bz2", () => {
const c = classifyArtifact(null, "archive.tar.bz2");
expect(c.openable).toBe(false);
});
it("does not open .tar.xz", () => {
const c = classifyArtifact(null, "archive.tar.xz");
expect(c.openable).toBe(false);
});
it("does not open common binary formats", () => {
const binaries = [
"setup.exe",
"library.dll",
"image.iso",
"installer.dmg",
"package.deb",
"package.rpm",
"module.wasm",
"Main.class",
"module.pyc",
"app.apk",
"game.pak",
"model.onnx",
"weights.pt",
"data.parquet",
"archive.rar",
"archive.7z",
"disk.vhd",
"disk.vmdk",
"firmware.bin",
"core.dump",
"database.sqlite",
"database.db",
"index.idx",
];
for (const file of binaries) {
const c = classifyArtifact(null, file);
expect(c.openable).toBe(false);
}
});
it("does not open binary MIME types even with a misleading extension", () => {
// Extension is unknown, MIME is binary
const c = classifyArtifact("application/x-executable", "run.elf");
expect(c.openable).toBe(false);
});
it("does not open files with random/made-up extensions", () => {
const weirdExts = [
"output.xyz",
"data.foo",
"file.asdf",
"thing.blargh",
"result.out",
"x.1234",
];
for (const file of weirdExts) {
const c = classifyArtifact(null, file);
expect(c.openable).toBe(false);
expect(c.type).toBe("download-only");
}
});
it("does not open font files", () => {
for (const file of ["sans.ttf", "serif.otf", "icon.woff", "icon.woff2"]) {
expect(classifyArtifact(null, file).openable).toBe(false);
}
});
it("does not open certificate/key files", () => {
// .pem and .key have no extension mapping and null MIME → download-only
for (const file of ["cert.pem", "server.key", "ca.crt", "id.p12"]) {
expect(classifyArtifact(null, file).openable).toBe(false);
}
});
});
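
Taken together, these fixtures pin down the classifier's precedence: size gate first, then exact filename, then extension, then MIME fallbacks. A few representative calls, with inputs and results taken from the tests above (the import path is assumed):

import { classifyArtifact } from "./helpers"; // assumed relative path

classifyArtifact("text/plain", "data.json").type; // "json" (extension beats MIME)
classifyArtifact("text/x-python", "noext").type; // "code" (text/x-* prefix fallback)
classifyArtifact("video/quicktime", "noext").type; // "download-only" (unsupported video MIME)
classifyArtifact(null, "archive.tar.gz").type; // "download-only" (".gz" is unmapped)
classifyArtifact("text/plain", "big.txt", 10 * 1024 * 1024 + 1).type; // "download-only" (size gate)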

View File

@@ -5,6 +5,7 @@ import {
FileText,
Image,
Table,
VideoCamera,
} from "@phosphor-icons/react";
import type { Icon } from "@phosphor-icons/react";
@@ -17,6 +18,7 @@ export interface ArtifactClassification {
| "csv"
| "json"
| "image"
| "video"
| "pdf"
| "text"
| "download-only";
@@ -38,6 +40,13 @@ const KIND: Record<string, ArtifactClassification> = {
openable: true,
hasSourceToggle: false,
},
video: {
type: "video",
icon: VideoCamera,
label: "Video",
openable: true,
hasSourceToggle: false,
},
pdf: {
type: "pdf",
icon: FileText,
@@ -113,8 +122,13 @@ const EXT_KIND: Record<string, string> = {
".svg": "image",
".bmp": "image",
".ico": "image",
".avif": "image",
".mp4": "video",
".webm": "video",
".m4v": "video",
".pdf": "pdf",
".csv": "csv",
".tsv": "csv",
".html": "html",
".htm": "html",
".jsx": "react",
@@ -122,11 +136,17 @@ const EXT_KIND: Record<string, string> = {
".md": "markdown",
".mdx": "markdown",
".json": "json",
".jsonl": "code",
".txt": "text",
".log": "text",
".ics": "text",
".vcf": "text",
".env": "code",
".gitignore": "code",
// code extensions
".js": "code",
".ts": "code",
".dart": "code",
".py": "code",
".rb": "code",
".go": "code",
@@ -142,11 +162,19 @@ const EXT_KIND: Record<string, string> = {
".sh": "code",
".bash": "code",
".zsh": "code",
".scss": "code",
".sass": "code",
".less": "code",
".graphql": "code",
".gql": "code",
".proto": "code",
".yml": "code",
".yaml": "code",
".toml": "code",
".ini": "code",
".cfg": "code",
".conf": "code",
".properties": "code",
".sql": "code",
".r": "code",
".lua": "code",
@@ -154,10 +182,16 @@ const EXT_KIND: Record<string, string> = {
".scala": "code",
};
const EXACT_FILENAME_KIND: Record<string, string> = {
dockerfile: "code",
makefile: "code",
};
// Exact-match MIME → kind (fallback when extension doesn't match).
const MIME_KIND: Record<string, string> = {
"application/pdf": "pdf",
"text/csv": "csv",
"text/tab-separated-values": "csv",
"text/html": "html",
"text/jsx": "react",
"text/tsx": "react",
@@ -166,6 +200,9 @@ const MIME_KIND: Record<string, string> = {
"text/markdown": "markdown",
"text/x-markdown": "markdown",
"application/json": "json",
"application/x-ndjson": "code",
"application/ndjson": "code",
"application/jsonl": "code",
"application/javascript": "code",
"text/javascript": "code",
"application/typescript": "code",
@@ -182,11 +219,37 @@ const BINARY_MIMES = new Set([
"application/x-rar-compressed",
"application/x-7z-compressed",
"application/octet-stream",
"application/wasm",
"application/x-executable",
"application/x-msdos-program",
"application/vnd.microsoft.portable-executable",
]);
const PREVIEWABLE_IMAGE_MIMES = new Set([
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/svg+xml",
"image/bmp",
"image/x-icon",
"image/vnd.microsoft.icon",
"image/avif",
]);
const PREVIEWABLE_VIDEO_MIMES = new Set([
"video/mp4",
"video/webm",
"video/x-m4v",
]);
function getBasename(filename?: string): string {
if (!filename) return "";
const normalized = filename.replace(/\\/g, "/");
const parts = normalized.split("/");
return parts[parts.length - 1]?.toLowerCase() ?? "";
}
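// Illustrative (assumed inputs): getBasename("out\\reports\\Q1.PDF") → "q1.pdf";
// getBasename("Makefile") → "makefile", the key EXACT_FILENAME_KIND matches.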
function getExtension(filename?: string): string {
if (!filename) return "";
const lastDot = filename.lastIndexOf(".");
@@ -202,24 +265,36 @@ export function classifyArtifact(
// Size gate: >10MB is download-only regardless of type.
if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];
const basename = getBasename(filename);
const exactKind = EXACT_FILENAME_KIND[basename];
if (exactKind) return KIND[exactKind];
if (basename === ".env" || basename.startsWith(".env.")) {
return KIND.code;
}
// Extension first (more reliable than MIME for AI-generated files).
const ext = getExtension(filename);
const ext = getExtension(basename);
const extKind = EXT_KIND[ext];
if (extKind) return KIND[extKind];
// MIME fallbacks.
const mime = (mimeType ?? "").toLowerCase();
if (mime.startsWith("image/")) return KIND.image;
if (PREVIEWABLE_IMAGE_MIMES.has(mime)) return KIND.image;
if (PREVIEWABLE_VIDEO_MIMES.has(mime)) return KIND.video;
const mimeKind = MIME_KIND[mime];
if (mimeKind) return KIND[mimeKind];
if (mime.startsWith("text/x-")) return KIND.code;
if (
BINARY_MIMES.has(mime) ||
mime.startsWith("audio/") ||
mime.startsWith("video/")
mime.startsWith("image/") ||
mime.startsWith("video/") ||
mime.startsWith("font/")
) {
return KIND["download-only"];
}
if (BINARY_MIMES.has(mime) || mime.startsWith("audio/")) {
return KIND["download-only"];
}
if (mime.startsWith("text/")) return KIND.text;
// Unknown extension + unknown MIME: don't open — we can't safely assume

View File

@@ -83,6 +83,7 @@ export function useArtifactPanel() {
const canCopy =
classification != null &&
classification.type !== "image" &&
classification.type !== "video" &&
classification.type !== "download-only" &&
classification.type !== "pdf";

View File

@@ -64,10 +64,7 @@ export const ChatContainer = ({
// open state drive layout width; an artifact generated in a stale session
// state would otherwise shrink the chat column with no panel rendered.
const isArtifactOpen = isArtifactsEnabled && isArtifactPanelOpen;
useAutoOpenArtifacts({
messages: isArtifactsEnabled ? messages : [],
sessionId,
});
useAutoOpenArtifacts({ sessionId });
const isBusy =
status === "streaming" ||
status === "submitted" ||

View File

@@ -0,0 +1,77 @@
import { describe, expect, it, beforeEach, afterEach } from "vitest";
import { renderHook } from "@testing-library/react";
import { useAutoOpenArtifacts } from "../useAutoOpenArtifacts";
import { useCopilotUIStore } from "../../../store";
// Capture the real store actions before any test can replace them.
const realOpenArtifact = useCopilotUIStore.getState().openArtifact;
const realResetArtifactPanel = useCopilotUIStore.getState().resetArtifactPanel;
function resetStore() {
useCopilotUIStore.setState({
openArtifact: realOpenArtifact,
resetArtifactPanel: realResetArtifactPanel,
artifactPanel: {
isOpen: false,
isMinimized: false,
isMaximized: false,
width: 600,
activeArtifact: null,
history: [],
},
});
}
describe("useAutoOpenArtifacts", () => {
beforeEach(resetStore);
afterEach(resetStore);
it("does not auto-open artifacts on initial message load", () => {
renderHook(() => useAutoOpenArtifacts({ sessionId: "session-1" }));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("does not auto-open when rerendering within the same session", () => {
const { rerender } = renderHook(
({ sessionId }: { sessionId: string }) =>
useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "session-1" } },
);
rerender({ sessionId: "session-1" });
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("panel should fully reset when session changes", () => {
const artifact = {
id: "file1",
title: "image.png",
mimeType: "image/png",
sourceUrl: "/api/proxy/api/workspace/files/file1/download",
origin: "agent" as const,
};
useCopilotUIStore.getState().openArtifact(artifact);
useCopilotUIStore.getState().openArtifact({
...artifact,
id: "file2",
title: "second.png",
sourceUrl: "/api/proxy/api/workspace/files/file2/download",
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
const { rerender } = renderHook(
({ sessionId }: { sessionId: string }) =>
useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "session-1" } },
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
rerender({ sessionId: "session-2" });
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
});

View File

@@ -3,17 +3,19 @@ import { beforeEach, describe, expect, it } from "vitest";
import { useCopilotUIStore } from "../../store";
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";
function assistantMessageWithText(id: string, text: string) {
return {
id,
role: "assistant" as const,
parts: [{ type: "text" as const, text }],
};
}
const A_ID = "11111111-0000-0000-0000-000000000000";
const B_ID = "22222222-0000-0000-0000-000000000000";
function makeArtifact(id: string, title = `${id}.txt`) {
return {
id,
title,
mimeType: "text/plain",
sourceUrl: `/api/proxy/api/workspace/files/${id}/download`,
origin: "agent" as const,
};
}
function resetStore() {
useCopilotUIStore.setState({
artifactPanel: {
@@ -30,111 +32,60 @@ function resetStore() {
describe("useAutoOpenArtifacts", () => {
beforeEach(resetStore);
it("does NOT auto-open on the initial hydration of message list (baseline pass)", () => {
const messages = [
assistantMessageWithText("m1", `[a](workspace://${A_ID})`),
];
renderHook(() =>
useAutoOpenArtifacts({ messages: messages as any, sessionId: "s1" }),
);
// Initial run just records the baseline fingerprint; nothing opens.
it("does not auto-open on initial render", () => {
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("auto-opens when an existing assistant message adds a new artifact", () => {
// 1st render: baseline with no artifact.
const initial = [assistantMessageWithText("m1", "thinking...")];
it("does not auto-open when rerendering within the same session", () => {
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{ initialProps: { messages: initial, sessionId: "s1" } },
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
// 2nd render: same message id now contains an artifact link.
act(() => {
rerender({
messages: [
assistantMessageWithText("m1", `here: [A](workspace://${A_ID})`),
],
sessionId: "s1",
});
rerender({ sessionId: "s1" });
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("resets the panel state when sessionId changes", () => {
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
const { rerender } = renderHook(
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
act(() => {
rerender({ sessionId: "s2" });
});
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe(A_ID);
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("does not re-open when the fingerprint hasn't changed", () => {
const msg = assistantMessageWithText("m1", `[A](workspace://${A_ID})`);
it("does not carry a stale back stack into the next session", () => {
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{ initialProps: { messages: [msg], sessionId: "s1" } },
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
{ initialProps: { sessionId: "s1" } },
);
// Baseline captured; no open.
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
// Rerender identical content: no change in fingerprint → no open.
act(() => {
rerender({ messages: [msg], sessionId: "s1" });
rerender({ sessionId: "s2" });
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
});
it("auto-opens when a brand-new assistant message arrives after the baseline is established", () => {
// First render: one message without artifacts → establishes baseline.
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{
initialProps: {
messages: [assistantMessageWithText("m1", "plain")] as any,
sessionId: "s1",
},
},
);
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
useCopilotUIStore.getState().openArtifact(makeArtifact("c", "c.txt"));
// Second render: a *new* assistant message with an artifact. Baseline
// is already set, so this should auto-open.
act(() => {
rerender({
messages: [
assistantMessageWithText("m1", "plain"),
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
] as any,
sessionId: "s1",
});
});
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe(B_ID);
});
it("resets hydration baseline when sessionId changes", () => {
const { rerender } = renderHook(
({ messages, sessionId }) =>
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
{
initialProps: {
messages: [
assistantMessageWithText("m1", `[A](workspace://${A_ID})`),
] as any,
sessionId: "s1",
},
},
);
// Switch to a new session — the first pass on the new session should
// NOT auto-open (it's a fresh hydration).
act(() => {
rerender({
messages: [
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
] as any,
sessionId: "s2",
});
});
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
expect(s.activeArtifact?.id).toBe("c");
expect(s.history).toEqual([]);
});
});

View File

@@ -1,91 +1,29 @@
"use client";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef } from "react";
import type { ArtifactRef } from "../../store";
import { useCopilotUIStore } from "../../store";
import { getMessageArtifacts } from "../ChatMessagesContainer/helpers";
function fingerprintArtifacts(artifacts: ArtifactRef[]): string {
return artifacts
.map((a) => `${a.id}:${a.title}:${a.mimeType ?? ""}:${a.sourceUrl}`)
.join("|");
}
interface UseAutoOpenArtifactsOptions {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
sessionId: string | null;
}
export function useAutoOpenArtifacts({
messages,
sessionId,
}: UseAutoOpenArtifactsOptions) {
const openArtifact = useCopilotUIStore((state) => state.openArtifact);
const messageFingerprintsRef = useRef<Map<string, string>>(new Map());
const hasInitializedRef = useRef(false);
const resetArtifactPanel = useCopilotUIStore(
(state) => state.resetArtifactPanel,
);
const prevSessionIdRef = useRef(sessionId);
useEffect(() => {
messageFingerprintsRef.current = new Map();
hasInitializedRef.current = false;
}, [sessionId]);
const isSessionChange = prevSessionIdRef.current !== sessionId;
prevSessionIdRef.current = sessionId;
useEffect(() => {
if (messages.length === 0) {
messageFingerprintsRef.current = new Map();
return;
// Artifact previews should open only from an explicit user click.
// When the session changes, fully clear the panel state so stale
// active artifacts and back-stack entries never bleed into the next chat.
if (isSessionChange) {
resetArtifactPanel();
}
// Only scan messages whose fingerprint might have changed since the
// last pass: that's the last assistant message (currently streaming)
// plus any assistant message whose id isn't in the baseline yet.
// This keeps the cost O(new+tail), not O(all messages), on every chunk.
const previous = messageFingerprintsRef.current;
const nextFingerprints = new Map<string, string>(previous);
let nextArtifact: ArtifactRef | null = null;
const lastAssistantIdx = (() => {
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === "assistant") return i;
}
return -1;
})();
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
if (message.role !== "assistant") continue;
const isTailAssistant = i === lastAssistantIdx;
const isNewMessage = !previous.has(message.id);
if (!isTailAssistant && !isNewMessage) continue;
const artifacts = getMessageArtifacts(message);
const fingerprint = fingerprintArtifacts(artifacts);
nextFingerprints.set(message.id, fingerprint);
if (!hasInitializedRef.current || fingerprint.length === 0) {
continue;
}
const previousFingerprint = previous.get(message.id) ?? "";
if (previousFingerprint === fingerprint) continue;
nextArtifact = artifacts[artifacts.length - 1] ?? nextArtifact;
}
// Drop entries for messages that no longer exist (e.g. history truncated).
const liveIds = new Set(messages.map((m) => m.id));
for (const id of nextFingerprints.keys()) {
if (!liveIds.has(id)) nextFingerprints.delete(id);
}
messageFingerprintsRef.current = nextFingerprints;
if (!hasInitializedRef.current) {
hasInitializedRef.current = true;
return;
}
if (nextArtifact) {
openArtifact(nextArtifact);
}
}, [messages, openArtifact]);
}, [sessionId, resetArtifactPanel]);
}

View File

@@ -19,8 +19,16 @@ describe("formatResetTime", () => {
});
it("returns formatted date when over 24 hours away", () => {
const result = formatResetTime("2025-06-17T00:00:00Z", now);
expect(result).toMatch(/Tue/);
const resetsAt = "2025-06-17T00:00:00Z";
const result = formatResetTime(resetsAt, now);
const expected = new Date(resetsAt).toLocaleString(undefined, {
weekday: "short",
hour: "numeric",
minute: "2-digit",
timeZoneName: "short",
});
expect(result).toBe(expected);
});
it("accepts a Date object for resetsAt", () => {

View File

@@ -99,6 +99,50 @@ describe("artifactPanel store actions", () => {
expect(s.history).toEqual([]);
});
it("openArtifact does not resurrect a previously closed artifact into history", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().closeArtifactPanel();
useCopilotUIStore.getState().openArtifact(b);
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(true);
expect(s.activeArtifact?.id).toBe("b");
expect(s.history).toEqual([]);
});
it("openArtifact ignores non-previewable artifacts", () => {
const binary = {
...makeArtifact("bin", "artifact.bin"),
mimeType: "application/octet-stream",
};
useCopilotUIStore.getState().openArtifact(binary);
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("resetArtifactPanel clears active artifact and history", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().openArtifact(b);
useCopilotUIStore.getState().maximizeArtifactPanel();
useCopilotUIStore.getState().resetArtifactPanel();
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.isMinimized).toBe(false);
expect(s.isMaximized).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
});
it("minimize/restore toggles isMinimized without touching activeArtifact", () => {
const a = makeArtifact("a");
useCopilotUIStore.getState().openArtifact(a);
@@ -138,4 +182,35 @@ describe("artifactPanel store actions", () => {
expect(s.width).toBe(720);
expect(s.isMaximized).toBe(false);
});
it("history is capped at 25 entries (MAX_HISTORY)", () => {
// Open 27 artifacts sequentially (A0..A26). History should never exceed 25.
for (let i = 0; i < 27; i++) {
useCopilotUIStore.getState().openArtifact(makeArtifact(`a${i}`));
}
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.activeArtifact?.id).toBe("a26");
expect(s.history.length).toBe(25);
// The oldest entry (a0) should have been dropped; a1 is the earliest surviving.
expect(s.history[0].id).toBe("a1");
expect(s.history[24].id).toBe("a25");
});
it("clearCopilotLocalData resets artifact panel to default", () => {
const a = makeArtifact("a");
const b = makeArtifact("b");
useCopilotUIStore.getState().openArtifact(a);
useCopilotUIStore.getState().openArtifact(b);
useCopilotUIStore.getState().maximizeArtifactPanel();
useCopilotUIStore.getState().clearCopilotLocalData();
const s = useCopilotUIStore.getState().artifactPanel;
expect(s.isOpen).toBe(false);
expect(s.isMinimized).toBe(false);
expect(s.isMaximized).toBe(false);
expect(s.activeArtifact).toBeNull();
expect(s.history).toEqual([]);
expect(s.width).toBe(600); // DEFAULT_PANEL_WIDTH
});
});

View File

@@ -1,6 +1,7 @@
import { Key, storage } from "@/services/storage/local-storage";
import { create } from "zustand";
import { clearContentCache } from "./components/ArtifactPanel/components/useArtifactContent";
import { classifyArtifact } from "./components/ArtifactPanel/helpers";
import { ORIGINAL_TITLE, parseSessionIDs } from "./helpers";
export interface DeleteTarget {
@@ -92,6 +93,10 @@ function persistCompletedSessions(ids: Set<string>) {
}
}
function isPreviewableArtifact(ref: ArtifactRef): boolean {
return classifyArtifact(ref.mimeType, ref.title, ref.sizeBytes).openable;
}
interface CopilotUIState {
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
initialPrompt: string | null;
@@ -121,6 +126,7 @@ interface CopilotUIState {
artifactPanel: ArtifactPanelState;
openArtifact: (ref: ArtifactRef) => void;
closeArtifactPanel: () => void;
resetArtifactPanel: () => void;
minimizeArtifactPanel: () => void;
maximizeArtifactPanel: () => void;
restoreArtifactPanel: () => void;
@@ -203,14 +209,20 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
},
openArtifact: (ref) =>
set((state) => {
if (!isPreviewableArtifact(ref)) return state;
const { activeArtifact, history: prevHistory } = state.artifactPanel;
const topOfHistory = prevHistory[prevHistory.length - 1];
const isReturningToTop = topOfHistory?.id === ref.id;
const shouldPushHistory =
state.artifactPanel.isOpen &&
activeArtifact != null &&
activeArtifact.id !== ref.id;
const MAX_HISTORY = 25;
const history = isReturningToTop
? prevHistory.slice(0, -1)
: activeArtifact && activeArtifact.id !== ref.id
? [...prevHistory, activeArtifact].slice(-MAX_HISTORY)
: shouldPushHistory
? [...prevHistory, activeArtifact!].slice(-MAX_HISTORY)
: prevHistory;
return {
artifactPanel: {
@@ -231,6 +243,17 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
history: [],
},
})),
resetArtifactPanel: () =>
set((state) => ({
artifactPanel: {
...state.artifactPanel,
isOpen: false,
isMinimized: false,
isMaximized: false,
activeArtifact: null,
history: [],
},
})),
minimizeArtifactPanel: () =>
set((state) => ({
artifactPanel: { ...state.artifactPanel, isMinimized: true },

View File

@@ -1,15 +1,13 @@
"use client";
import React, { useState } from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { Button } from "@/components/atoms/Button/Button";
import type { BlockOutputResponse } from "@/app/api/__generated__/models/blockOutputResponse";
import {
globalRegistry,
OutputItem,
} from "@/components/contextual/OutputRenderers";
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
import { resolveForRenderer } from "@/app/(platform)/copilot/tools/ViewAgentOutput/ViewAgentOutput";
import {
ContentBadge,
ContentCard,
@@ -24,28 +22,6 @@ interface Props {
const COLLAPSED_LIMIT = 3;
function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
if (!isWorkspaceURI(value)) return { value };
const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;
const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}
return { value: url, metadata };
}
function RenderOutputValue({ value }: { value: unknown }) {
const resolved = resolveForRenderer(value);
const renderer = globalRegistry.getRenderer(
@@ -63,16 +39,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}
// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}
return null;
}

View File

@@ -2,7 +2,6 @@
import type { ToolUIPart } from "ai";
import React from "react";
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
globalRegistry,
OutputItem,
@@ -47,7 +46,7 @@ interface Props {
part: ViewAgentOutputToolPart;
}
function resolveForRenderer(value: unknown): {
export function resolveForRenderer(value: unknown): {
value: unknown;
metadata?: OutputMetadata;
} {
@@ -56,17 +55,17 @@ function resolveForRenderer(value: unknown): {
const parsed = parseWorkspaceURI(value);
if (!parsed) return { value };
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
const url = `/api/proxy${apiPath}`;
// Pass workspace URIs through to the registry unchanged.
// WorkspaceFileRenderer (priority 50) matches workspace:// URIs and
// handles URL building, loading skeletons, and error states internally.
// Previously this converted to a proxy URL which bypassed
// WorkspaceFileRenderer, causing ImageRenderer (bare <img>) to match.
const metadata: OutputMetadata = {};
if (parsed.mimeType) {
metadata.mimeType = parsed.mimeType;
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
}
return { value: url, metadata };
return { value, metadata };
}
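// Illustrative flow (URI shape assumed "workspace://<fileID>#<mimeType>"):
// "workspace://abc123#image/png" now reaches the registry unchanged with
// { mimeType: "image/png", type: "image" }, so WorkspaceFileRenderer
// (priority 50) matches instead of a bare <img> ImageRenderer.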
function RenderOutputValue({ value }: { value: unknown }) {
@@ -86,16 +85,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
);
}
// Fallback for audio workspace refs
if (
isWorkspaceURI(value) &&
resolved.metadata?.mimeType?.startsWith("audio/")
) {
return (
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
);
}
return null;
}

View File

@@ -0,0 +1,52 @@
import { describe, expect, it } from "vitest";
import { resolveForRenderer } from "../ViewAgentOutput";
import { globalRegistry } from "@/components/contextual/OutputRenderers";
describe("resolveForRenderer", () => {
it("preserves workspace image URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://abc123#image/png");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("image/png");
});
it("preserves workspace video URI for the registry to handle", () => {
const result = resolveForRenderer("workspace://vid456#video/mp4");
expect(String(result.value)).toMatch(/^workspace:\/\//);
expect(result.metadata?.mimeType).toBe("video/mp4");
});
it("passes non-workspace values through unchanged", () => {
const result = resolveForRenderer("just a string");
expect(result.value).toBe("just a string");
expect(result.metadata).toBeUndefined();
});
it("passes non-string values through unchanged", () => {
const obj = { foo: "bar" };
const result = resolveForRenderer(obj);
expect(result.value).toBe(obj);
expect(result.metadata).toBeUndefined();
});
it("workspace image URIs match WorkspaceFileRenderer with loading/error states", () => {
// WorkspaceFileRenderer (priority 50) should handle workspace:// URIs
// since resolveForRenderer no longer pre-converts them to proxy URLs.
const resolved = resolveForRenderer("workspace://abc123#image/png");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});
it("workspace video URIs match WorkspaceFileRenderer", () => {
const resolved = resolveForRenderer("workspace://vid456#video/mp4");
const renderer = globalRegistry.getRenderer(
resolved.value,
resolved.metadata,
);
expect(renderer).toBeDefined();
expect(renderer!.name).toBe("WorkspaceFileRenderer");
});
});

View File

@@ -0,0 +1,96 @@
import {
getGetV2GetSpecificAgentMockHandler,
getGetV2GetSpecificAgentResponseMock,
getGetV2ListStoreAgentsMockHandler,
getGetV2ListStoreAgentsResponseMock,
} from "@/app/api/__generated__/endpoints/store/store.msw";
import { server } from "@/mocks/mock-server";
import { render, screen } from "@/tests/integrations/test-utils";
import { MainAgentPage } from "../MainAgentPage";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseSupabase = vi.hoisted(() => vi.fn());
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: mockUseSupabase,
}));
describe("MainAgentPage", () => {
beforeEach(() => {
mockUseSupabase.mockReturnValue({
user: null,
});
});
test("renders the marketplace agent details and related sections", async () => {
const agentDetails = getGetV2GetSpecificAgentResponseMock({
agent_name: "Deterministic Agent",
creator: "AutoGPT",
creator_avatar: "",
sub_heading: "A stable marketplace listing",
description: "This agent is used for integration coverage.",
categories: ["demo", "test"],
versions: ["1", "2"],
active_version_id: "store-version-1",
store_listing_version_id: "listing-1",
agent_image: ["https://example.com/agent.png"],
agent_output_demo: "",
agent_video: "",
});
const otherAgents = getGetV2ListStoreAgentsResponseMock({
agents: [
{
...getGetV2ListStoreAgentsResponseMock().agents[0],
slug: "other-agent",
agent_name: "Other Agent",
creator: "AutoGPT",
},
],
});
const similarAgents = getGetV2ListStoreAgentsResponseMock({
agents: [
{
...getGetV2ListStoreAgentsResponseMock().agents[0],
slug: "similar-agent",
agent_name: "Similar Agent",
creator: "Another Creator",
},
],
});
server.use(
getGetV2GetSpecificAgentMockHandler(agentDetails),
getGetV2ListStoreAgentsMockHandler(({ request }) => {
const url = new URL(request.url);
if (url.searchParams.get("creator") === "autogpt") {
return otherAgents;
}
if (url.searchParams.get("search_query") === "deterministic agent") {
return similarAgents;
}
return getGetV2ListStoreAgentsResponseMock({ agents: [] });
}),
);
render(
<MainAgentPage
params={{ creator: "autogpt", slug: "deterministic-agent" }}
/>,
);
expect((await screen.findByTestId("agent-title")).textContent).toContain(
"Deterministic Agent",
);
expect(screen.getByTestId("agent-description").textContent).toContain(
"This agent is used for integration coverage.",
);
expect(screen.getByTestId("agent-creator").textContent).toContain(
"AutoGPT",
);
expect(screen.getByText("Other agents by AutoGPT")).toBeDefined();
expect(screen.getByText("Similar agents")).toBeDefined();
});
});

View File

@@ -1,15 +1,64 @@
import {
getGetV2ListStoreAgentsResponseMock,
getGetV2ListStoreCreatorsResponseMock,
} from "@/app/api/__generated__/endpoints/store/store.msw";
import { render, screen } from "@/tests/integrations/test-utils";
import { MainMarkeplacePage } from "../MainMarketplacePage";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseMainMarketplacePage = vi.hoisted(() => vi.fn());
vi.mock("../useMainMarketplacePage", () => ({
useMainMarketplacePage: mockUseMainMarketplacePage,
}));
describe("MainMarketplacePage", () => {
beforeEach(() => {
mockUseMainMarketplacePage.mockReturnValue({
featuredAgents: getGetV2ListStoreAgentsResponseMock({
agents: [
{
...getGetV2ListStoreAgentsResponseMock().agents[0],
slug: "featured-agent",
agent_name: "Featured Agent",
creator: "AutoGPT",
},
],
}),
topAgents: getGetV2ListStoreAgentsResponseMock({
agents: [
{
...getGetV2ListStoreAgentsResponseMock().agents[0],
slug: "top-agent",
agent_name: "Top Agent",
creator: "AutoGPT",
},
],
}),
featuredCreators: getGetV2ListStoreCreatorsResponseMock({
creators: [
{
...getGetV2ListStoreCreatorsResponseMock().creators[0],
name: "Creator One",
username: "creator-one",
},
],
}),
isLoading: false,
hasError: false,
});
});
test("renders featured agents, all agents, and creators", () => {
render(<MainMarkeplacePage />);
expect(screen.getByText(/Featured agents/i)).toBeDefined();
expect(screen.getByText("Featured Agent")).toBeDefined();
expect(screen.getByText("All Agents")).toBeDefined();
expect(screen.getAllByText("Top Agent").length).toBeGreaterThan(0);
expect(screen.getByText("Creator One")).toBeDefined();
expect(
screen.getByRole("button", { name: "Become a Creator" }),
).toBeDefined();
});
});

View File

@@ -0,0 +1,57 @@
import { render, screen } from "@/tests/integrations/test-utils";
import {
getGetV2GetCreatorDetailsResponseMock,
getGetV2ListStoreAgentsResponseMock,
} from "@/app/api/__generated__/endpoints/store/store.msw";
import { MainCreatorPage } from "../MainCreatorPage";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseMainCreatorPage = vi.hoisted(() => vi.fn());
vi.mock("../useMainCreatorPage", () => ({
useMainCreatorPage: mockUseMainCreatorPage,
}));
describe("MainCreatorPage", () => {
beforeEach(() => {
const creator = getGetV2GetCreatorDetailsResponseMock({
name: "Creator One",
username: "creator-one",
description: "Creator profile used for integration coverage.",
avatar_url: "",
top_categories: ["automation", "productivity"],
links: ["https://example.com/creator"],
});
const creatorAgents = getGetV2ListStoreAgentsResponseMock({
agents: [
{
...getGetV2ListStoreAgentsResponseMock().agents[0],
slug: "creator-agent",
agent_name: "Creator Agent",
creator: "Creator One",
},
],
});
mockUseMainCreatorPage.mockReturnValue({
creatorAgents,
creator,
isLoading: false,
hasError: false,
});
});
test("renders creator details and their agents", () => {
render(<MainCreatorPage params={{ creator: "creator-one" }} />);
expect(screen.getByTestId("creator-title").textContent).toContain(
"Creator One",
);
expect(screen.getByTestId("creator-description").textContent).toContain(
"Creator profile used for integration coverage.",
);
expect(screen.getByText("Agents by Creator One")).toBeDefined();
expect(screen.getAllByText("Creator Agent").length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,83 @@
import type { ReactNode } from "react";
import {
render,
screen,
fireEvent,
waitFor,
} from "@/tests/integrations/test-utils";
import {
getGetV2GetUserProfileMockHandler,
getPostV2UpdateUserProfileMockHandler,
} from "@/app/api/__generated__/endpoints/store/store.msw";
import { server } from "@/mocks/mock-server";
import UserProfilePage from "../page";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseSupabase = vi.hoisted(() => vi.fn());
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
default: ({ children }: { children: ReactNode }) => <>{children}</>,
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: mockUseSupabase,
}));
const testUser = {
id: "user-1",
email: "user@example.com",
app_metadata: {},
user_metadata: {},
aud: "authenticated",
created_at: "2026-01-01T00:00:00.000Z",
};
describe("UserProfilePage", () => {
beforeEach(() => {
mockUseSupabase.mockReturnValue({
user: testUser,
isLoggedIn: true,
isUserLoading: false,
supabase: {},
});
});
test("renders the existing profile and saves changes", async () => {
let profile = {
name: "Original Name",
username: "original-user",
description: "Original bio",
links: ["https://example.com/1", "", "", "", ""],
avatar_url: "",
is_featured: false,
};
server.use(
getGetV2GetUserProfileMockHandler(() => profile),
getPostV2UpdateUserProfileMockHandler(async ({ request }) => {
profile = (await request.json()) as typeof profile;
return profile;
}),
);
render(<UserProfilePage />);
const displayName = await screen.findByLabelText("Display name");
const handle = screen.getByLabelText("Handle");
const bio = screen.getByLabelText("Bio");
expect((displayName as HTMLInputElement).value).toBe("Original Name");
expect((handle as HTMLInputElement).value).toBe("original-user");
fireEvent.change(displayName, { target: { value: "Updated Name" } });
fireEvent.change(handle, { target: { value: "updated-user" } });
fireEvent.change(bio, { target: { value: "Updated bio" } });
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
await waitFor(() => {
expect(profile.name).toBe("Updated Name");
expect(profile.username).toBe("updated-user");
expect(profile.description).toBe("Updated bio");
});
});
});

View File

@@ -0,0 +1,138 @@
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import {
getDeleteV1RevokeApiKeyMockHandler,
getGetV1ListUserApiKeysMockHandler,
getPostV1CreateNewApiKeyMockHandler,
} from "@/app/api/__generated__/endpoints/api-keys/api-keys.msw";
import { APIKeyPermission } from "@/app/api/__generated__/models/aPIKeyPermission";
import { APIKeyStatus } from "@/app/api/__generated__/models/aPIKeyStatus";
import { server } from "@/mocks/mock-server";
import ApiKeysPage from "../page";
import { beforeEach, describe, expect, test } from "vitest";
type ApiKeyRecord = {
id: string;
name: string;
head: string;
tail: string;
status: APIKeyStatus;
};
function toApiKeyResponse(key: ApiKeyRecord) {
return {
id: key.id,
user_id: "user-1",
scopes: [APIKeyPermission.EXECUTE_GRAPH],
type: "api_key" as const,
created_at: new Date("2026-01-01T00:00:00.000Z"),
expires_at: null,
last_used_at: null,
revoked_at: null,
name: key.name,
head: key.head,
tail: key.tail,
status: key.status,
description: null,
};
}
describe("ApiKeysPage", () => {
let apiKeys: ApiKeyRecord[];
let revokedKeyId: string;
beforeEach(() => {
apiKeys = [];
revokedKeyId = "";
server.use(
getGetV1ListUserApiKeysMockHandler(() =>
apiKeys.map((key) => toApiKeyResponse(key)),
),
getPostV1CreateNewApiKeyMockHandler(async ({ request }) => {
const body = (await request.json()) as {
name: string;
description?: string;
permissions?: APIKeyPermission[];
};
const createdKey: ApiKeyRecord = {
id: `key-${apiKeys.length + 1}`,
name: body.name,
head: "head",
tail: "tail",
status: APIKeyStatus.ACTIVE,
};
apiKeys = [...apiKeys, createdKey];
return {
api_key: toApiKeyResponse(createdKey),
plain_text_key: "plain-text-key",
};
}),
getDeleteV1RevokeApiKeyMockHandler(({ params }) => {
const keyId = String(params.keyId);
const removedKey = apiKeys.find((key) => key.id === keyId);
revokedKeyId = keyId;
apiKeys = apiKeys.filter((key) => key.id !== keyId);
return toApiKeyResponse(
removedKey ?? {
id: keyId,
name: "Unknown key",
head: "head",
tail: "tail",
status: APIKeyStatus.REVOKED,
},
);
}),
);
});
test("creates a new API key", async () => {
render(<ApiKeysPage />);
fireEvent.click(await screen.findByText("Create Key"));
fireEvent.change(screen.getByLabelText("Name"), {
target: { value: "CLI Key" },
});
fireEvent.click(screen.getByText("Create"));
expect(
await screen.findByText("AutoGPT Platform API Key Created"),
).toBeDefined();
await waitFor(() => {
expect(apiKeys[0]?.name).toBe("CLI Key");
});
});
test("revokes an existing API key", async () => {
apiKeys = [
{
id: "key-1",
name: "Existing Key",
head: "head",
tail: "tail",
status: APIKeyStatus.ACTIVE,
},
];
render(<ApiKeysPage />);
expect(await screen.findByText("Existing Key")).toBeDefined();
fireEvent.pointerDown(screen.getByTestId("api-key-actions"));
fireEvent.click(await screen.findByRole("menuitem", { name: "Revoke" }));
await waitFor(() => {
expect(revokedKeyId).toBe("key-1");
});
});
});

View File

@@ -0,0 +1,76 @@
import { render, screen, fireEvent } from "@testing-library/react";
import { getGetV2ListMySubmissionsResponseMock } from "@/app/api/__generated__/endpoints/store/store.msw";
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
import { AgentTableRow } from "../AgentTableRow";
import { beforeEach, describe, expect, test, vi } from "vitest";
function makeSubmission(status: SubmissionStatus) {
const submission = getGetV2ListMySubmissionsResponseMock().submissions[0];
return {
...submission,
graph_id: "graph-1",
graph_version: 7,
listing_version_id: `listing-${status.toLowerCase()}`,
name: `Agent ${status}`,
description: `Description ${status}`,
status,
image_urls: [],
submitted_at: new Date("2026-01-01T00:00:00.000Z"),
};
}
describe("AgentTableRow", () => {
const onViewSubmission = vi.fn();
const onDeleteSubmission = vi.fn();
const onEditSubmission = vi.fn();
beforeEach(() => {
onViewSubmission.mockReset();
onDeleteSubmission.mockReset();
onEditSubmission.mockReset();
});
test("shows edit and delete actions for pending submissions", async () => {
render(
<AgentTableRow
storeAgentSubmission={makeSubmission(SubmissionStatus.PENDING)}
onViewSubmission={onViewSubmission}
onDeleteSubmission={onDeleteSubmission}
onEditSubmission={onEditSubmission}
/>,
);
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
fireEvent.click(await screen.findByText("Edit"));
expect(onEditSubmission).toHaveBeenCalledTimes(1);
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
fireEvent.click(await screen.findByText("Delete"));
expect(onDeleteSubmission).toHaveBeenCalledWith("listing-pending");
expect(onViewSubmission).not.toHaveBeenCalled();
});
test("shows view only for non-pending submissions", async () => {
const approvedSubmission = makeSubmission(SubmissionStatus.APPROVED);
render(
<AgentTableRow
storeAgentSubmission={approvedSubmission}
onViewSubmission={onViewSubmission}
onDeleteSubmission={onDeleteSubmission}
onEditSubmission={onEditSubmission}
/>,
);
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
const viewAction = await screen.findByText("View");
fireEvent.click(viewAction);
expect(onViewSubmission).toHaveBeenCalledWith(approvedSubmission);
expect(screen.queryByText("Edit")).toBeNull();
expect(screen.queryByText("Delete")).toBeNull();
});
});

View File

@@ -0,0 +1,147 @@
import type { ReactNode } from "react";
import {
render,
screen,
fireEvent,
waitFor,
} from "@/tests/integrations/test-utils";
import {
getGetV1GetNotificationPreferencesMockHandler,
getGetV1GetUserTimezoneMockHandler,
getPostV1UpdateNotificationPreferencesMockHandler,
getPostV1UpdateUserEmailMockHandler,
} from "@/app/api/__generated__/endpoints/auth/auth.msw";
import { server } from "@/mocks/mock-server";
import SettingsPage from "../page";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseSupabase = vi.hoisted(() => vi.fn());
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
default: ({ children }: { children: ReactNode }) => <>{children}</>,
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: mockUseSupabase,
}));
const testUser = {
id: "user-1",
email: "user@example.com",
app_metadata: {},
user_metadata: {},
aud: "authenticated",
created_at: "2026-01-01T00:00:00.000Z",
};
describe("SettingsPage", () => {
beforeEach(() => {
mockUseSupabase.mockReturnValue({
user: testUser,
isLoggedIn: true,
isUserLoading: false,
supabase: {},
});
});
test("renders the account actions", async () => {
server.use(
getGetV1GetNotificationPreferencesMockHandler({
user_id: "user-1",
email: "user@example.com",
preferences: {
AGENT_RUN: true,
ZERO_BALANCE: false,
LOW_BALANCE: false,
BLOCK_EXECUTION_FAILED: true,
CONTINUOUS_AGENT_ERROR: false,
DAILY_SUMMARY: false,
WEEKLY_SUMMARY: true,
MONTHLY_SUMMARY: false,
AGENT_APPROVED: true,
AGENT_REJECTED: true,
},
daily_limit: 0,
emails_sent_today: 0,
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
}),
getGetV1GetUserTimezoneMockHandler({ timezone: "Asia/Kolkata" }),
getPostV1UpdateUserEmailMockHandler({}),
getPostV1UpdateNotificationPreferencesMockHandler({
user_id: "user-1",
email: "user@example.com",
preferences: {},
daily_limit: 0,
emails_sent_today: 0,
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
}),
);
render(<SettingsPage />);
const emailInput = await screen.findByLabelText("Email");
expect((emailInput as HTMLInputElement).value).toBe("user@example.com");
expect(
screen.getByRole("link", { name: "Reset password" }).getAttribute("href"),
).toBe("/reset-password");
});
test("saves notification preference changes", async () => {
let submittedPreferences:
| {
email: string;
preferences: Record<string, boolean>;
}
| undefined;
server.use(
getGetV1GetNotificationPreferencesMockHandler({
user_id: "user-1",
email: "user@example.com",
preferences: {
AGENT_RUN: false,
ZERO_BALANCE: false,
LOW_BALANCE: false,
BLOCK_EXECUTION_FAILED: false,
CONTINUOUS_AGENT_ERROR: false,
DAILY_SUMMARY: false,
WEEKLY_SUMMARY: false,
MONTHLY_SUMMARY: false,
AGENT_APPROVED: false,
AGENT_REJECTED: false,
},
daily_limit: 0,
emails_sent_today: 0,
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
}),
getGetV1GetUserTimezoneMockHandler({ timezone: "Asia/Kolkata" }),
getPostV1UpdateUserEmailMockHandler({}),
getPostV1UpdateNotificationPreferencesMockHandler(async ({ request }) => {
submittedPreferences = (await request.json()) as {
email: string;
preferences: Record<string, boolean>;
};
return {
user_id: "user-1",
email: submittedPreferences.email,
preferences: submittedPreferences.preferences,
daily_limit: 0,
emails_sent_today: 0,
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
};
}),
);
render(<SettingsPage />);
fireEvent.click(
await screen.findByRole("switch", { name: "Agent Run Notifications" }),
);
fireEvent.click(screen.getByRole("button", { name: "Save preferences" }));
await waitFor(() => {
expect(submittedPreferences?.preferences.AGENT_RUN).toBe(true);
});
});
});

View File

@@ -0,0 +1,97 @@
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import type { ReactNode } from "react";
import type { User } from "@supabase/supabase-js";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { EmailForm } from "../EmailForm";
const mockToast = vi.hoisted(() => vi.fn());
const mockMutateAsync = vi.hoisted(() => vi.fn());
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
vi.mock("@/app/api/__generated__/endpoints/auth/auth", () => ({
usePostV1UpdateUserEmail: () => ({
mutateAsync: mockMutateAsync,
isPending: false,
}),
}));
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
default: ({ children }: { children: ReactNode }) => <>{children}</>,
}));
const testUser = {
id: "user-1",
email: "user@example.com",
app_metadata: {},
user_metadata: {},
aud: "authenticated",
created_at: "2026-01-01T00:00:00.000Z",
} as User;
describe("EmailForm", () => {
beforeEach(() => {
mockToast.mockReset();
mockMutateAsync.mockReset();
mockMutateAsync.mockResolvedValue({});
});
afterEach(() => {
vi.unstubAllGlobals();
});
test("submits a changed email to both update endpoints", async () => {
const fetchMock = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({}),
});
vi.stubGlobal("fetch", fetchMock);
render(<EmailForm user={testUser} />);
fireEvent.change(screen.getByLabelText("Email"), {
target: { value: "updated@example.com" },
});
fireEvent.click(screen.getByRole("button", { name: "Update email" }));
await waitFor(() => {
expect(fetchMock).toHaveBeenCalledWith("/api/auth/user", {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email: "updated@example.com" }),
});
});
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalledWith({
data: "updated@example.com",
});
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Successfully updated email",
}),
);
});
test("keeps submit disabled when the email has not changed", () => {
render(<EmailForm user={testUser} />);
expect(
(
screen.getByRole("button", {
name: "Update email",
}) as HTMLButtonElement
).disabled,
).toBe(true);
});
});

View File

@@ -55,6 +55,7 @@ export function NotificationForm({ preferences, user }: NotificationFormProps) {
</div>
<FormControl>
<Switch
aria-label="Agent Run Notifications"
checked={field.value}
onCheckedChange={field.onChange}
/>

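The added aria-label is what lets the SettingsPage test above address this control by its accessible name. A minimal sketch of the query it enables (the same Testing Library calls that test uses):

```
// With the aria-label in place, the switch is reachable by role + name:
const toggle = await screen.findByRole("switch", {
  name: "Agent Run Notifications",
});
fireEvent.click(toggle);
```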
View File

@@ -0,0 +1,73 @@
import type { ReactNode } from "react";
import {
render,
screen,
fireEvent,
waitFor,
} from "@/tests/integrations/test-utils";
import SignupPage from "../page";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseSupabase = vi.hoisted(() => vi.fn());
const mockSignupAction = vi.hoisted(() => vi.fn());
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
default: ({ children }: { children: ReactNode }) => <>{children}</>,
}));
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
useSupabase: mockUseSupabase,
}));
vi.mock("../actions", () => ({
signup: mockSignupAction,
}));
describe("SignupPage", () => {
beforeEach(() => {
mockUseSupabase.mockReturnValue({
supabase: {},
user: null,
isUserLoading: false,
isLoggedIn: false,
});
mockSignupAction.mockReset();
});
test("shows existing user feedback from signup action", async () => {
mockSignupAction.mockResolvedValue({
success: false,
error: "user_already_exists",
});
render(<SignupPage />);
fireEvent.change(screen.getByLabelText("Email"), {
target: { value: "existing@example.com" },
});
fireEvent.change(screen.getByLabelText("Password", { selector: "input" }), {
target: { value: "validpassword123" },
});
fireEvent.change(
screen.getByLabelText("Confirm Password", { selector: "input" }),
{
target: { value: "validpassword123" },
},
);
fireEvent.click(screen.getByRole("checkbox"));
fireEvent.click(screen.getByRole("button", { name: "Sign up" }));
await waitFor(() => {
expect(mockSignupAction).toHaveBeenCalledWith(
"existing@example.com",
"validpassword123",
"validpassword123",
true,
);
});
expect(
await screen.findByText("User with this email already exists"),
).toBeDefined();
});
});

View File

@@ -1,25 +0,0 @@
/**
* Generated by orval v7.13.0 🍺
* Do not edit manually.
* AutoGPT Agent Server
* This server is used to execute agents that are created by the AutoGPT system.
* OpenAPI spec version: 0.1
*/
import type { ResponseType } from "./responseType";
import type { BlockOutputResponseSessionId } from "./blockOutputResponseSessionId";
import type { BlockOutputResponseOutputs } from "./blockOutputResponseOutputs";
import type { BlockOutputResponseIsDryRun } from "./blockOutputResponseIsDryRun";
/**
* Response for run_block tool.
*/
export interface BlockOutputResponse {
type?: ResponseType;
message: string;
session_id?: BlockOutputResponseSessionId;
block_id: string;
block_name: string;
outputs: BlockOutputResponseOutputs;
success?: boolean;
is_dry_run?: BlockOutputResponseIsDryRun;
}

View File

@@ -1,36 +0,0 @@
/**
* Generated by orval v7.13.0 🍺
* Do not edit manually.
* AutoGPT Agent Server
* This server is used to execute agents that are created by the AutoGPT system.
* OpenAPI spec version: 0.1
*/
import type { GraphExecutionMetaInputs } from "./graphExecutionMetaInputs";
import type { GraphExecutionMetaCredentialInputs } from "./graphExecutionMetaCredentialInputs";
import type { GraphExecutionMetaNodesInputMasks } from "./graphExecutionMetaNodesInputMasks";
import type { GraphExecutionMetaPresetId } from "./graphExecutionMetaPresetId";
import type { AgentExecutionStatus } from "./agentExecutionStatus";
import type { GraphExecutionMetaStartedAt } from "./graphExecutionMetaStartedAt";
import type { GraphExecutionMetaEndedAt } from "./graphExecutionMetaEndedAt";
import type { GraphExecutionMetaShareToken } from "./graphExecutionMetaShareToken";
import type { GraphExecutionMetaStats } from "./graphExecutionMetaStats";
export interface GraphExecutionMeta {
id: string;
user_id: string;
graph_id: string;
graph_version: number;
inputs: GraphExecutionMetaInputs;
credential_inputs: GraphExecutionMetaCredentialInputs;
nodes_input_masks: GraphExecutionMetaNodesInputMasks;
preset_id: GraphExecutionMetaPresetId;
status: AgentExecutionStatus;
/** When execution started running. Null if not yet started (QUEUED). */
started_at?: GraphExecutionMetaStartedAt;
/** When execution finished. Null if not yet completed (QUEUED, RUNNING, INCOMPLETE, REVIEW). */
ended_at?: GraphExecutionMetaEndedAt;
is_shared?: boolean;
share_token?: GraphExecutionMetaShareToken;
is_dry_run?: boolean;
stats: GraphExecutionMetaStats;
}

View File

@@ -1,15 +0,0 @@
/**
* Generated by orval v7.13.0 🍺
* Do not edit manually.
* AutoGPT Agent Server
* This server is used to execute agents that are created by the AutoGPT system.
* OpenAPI spec version: 0.1
*/
import type { SuggestedTheme } from "./suggestedTheme";
/**
* Response model for user-specific suggested prompts grouped by theme.
*/
export interface SuggestedPromptsResponse {
themes: SuggestedTheme[];
}

View File

@@ -1,15 +0,0 @@
/**
* Generated by orval v7.13.0 🍺
* Do not edit manually.
* AutoGPT Agent Server
* This server is used to execute agents that are created by the AutoGPT system.
* OpenAPI spec version: 0.1
*/
/**
* A themed group of suggested prompts.
*/
export interface SuggestedTheme {
name: string;
prompts: string[];
}

View File

@@ -9123,6 +9123,15 @@
],
"title": "ContentType"
},
"CostBucket": {
"properties": {
"bucket": { "type": "string", "title": "Bucket" },
"count": { "type": "integer", "title": "Count" }
},
"type": "object",
"required": ["bucket", "count"],
"title": "CostBucket"
},
"CostLogRow": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -12141,7 +12150,58 @@
"title": "Total Cost Microdollars"
},
"total_requests": { "type": "integer", "title": "Total Requests" },
"total_users": { "type": "integer", "title": "Total Users" }
"total_users": { "type": "integer", "title": "Total Users" },
"total_input_tokens": {
"type": "integer",
"title": "Total Input Tokens",
"default": 0
},
"total_output_tokens": {
"type": "integer",
"title": "Total Output Tokens",
"default": 0
},
"avg_input_tokens_per_request": {
"type": "number",
"title": "Avg Input Tokens Per Request",
"default": 0.0
},
"avg_output_tokens_per_request": {
"type": "number",
"title": "Avg Output Tokens Per Request",
"default": 0.0
},
"avg_cost_microdollars_per_request": {
"type": "number",
"title": "Avg Cost Microdollars Per Request",
"default": 0.0
},
"cost_p50_microdollars": {
"type": "number",
"title": "Cost P50 Microdollars",
"default": 0.0
},
"cost_p75_microdollars": {
"type": "number",
"title": "Cost P75 Microdollars",
"default": 0.0
},
"cost_p95_microdollars": {
"type": "number",
"title": "Cost P95 Microdollars",
"default": 0.0
},
"cost_p99_microdollars": {
"type": "number",
"title": "Cost P99 Microdollars",
"default": 0.0
},
"cost_buckets": {
"items": { "$ref": "#/components/schemas/CostBucket" },
"type": "array",
"title": "Cost Buckets",
"default": []
}
},
"type": "object",
"required": [
@@ -15585,7 +15645,12 @@
"title": "Total Cache Creation Tokens",
"default": 0
},
"request_count": { "type": "integer", "title": "Request Count" }
"request_count": { "type": "integer", "title": "Request Count" },
"cost_bearing_request_count": {
"type": "integer",
"title": "Cost Bearing Request Count",
"default": 0
}
},
"type": "object",
"required": [

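For context, an instance of the expanded stats object would serialize along these lines. The field names come from the schema above; the concrete values and the `CostBucket` bucket labels are invented for illustration.

```
// Illustrative payload only — field names from the schema above, values invented.
const exampleStats = {
  total_cost_microdollars: 1250000,
  total_requests: 480,
  total_users: 12,
  total_input_tokens: 910000,
  total_output_tokens: 150000,
  avg_input_tokens_per_request: 1895.8,
  avg_output_tokens_per_request: 312.5,
  avg_cost_microdollars_per_request: 2604.2,
  cost_p50_microdollars: 1800.0,
  cost_p75_microdollars: 2950.0,
  cost_p95_microdollars: 7400.0,
  cost_p99_microdollars: 15200.0,
  // Each CostBucket pairs a bucket label with a request count; labels here are hypothetical.
  cost_buckets: [
    { bucket: "<1000", count: 210 },
    { bucket: "1000-10000", count: 240 },
    { bucket: ">10000", count: 30 },
  ],
};
```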
View File

@@ -0,0 +1,282 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import {
isWorkspaceDownloadRequest,
isRedirectStatus,
isTransientWorkspaceDownloadStatus,
getWorkspaceDownloadErrorMessage,
fetchWorkspaceDownloadOnce,
fetchWorkspaceDownloadWithRetry,
} from "./route.helpers";
describe("isWorkspaceDownloadRequest", () => {
it("matches api/workspace/files/{id}/download pattern", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"abc-123",
"download",
]),
).toBe(true);
});
it("rejects paths with wrong segment count", () => {
expect(
isWorkspaceDownloadRequest(["api", "workspace", "files", "download"]),
).toBe(false);
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"id",
"download",
"extra",
]),
).toBe(false);
});
it("rejects paths with wrong prefix", () => {
expect(
isWorkspaceDownloadRequest([
"v1",
"workspace",
"files",
"id",
"download",
]),
).toBe(false);
});
it("rejects paths not ending with download", () => {
expect(
isWorkspaceDownloadRequest([
"api",
"workspace",
"files",
"id",
"metadata",
]),
).toBe(false);
});
});
describe("isRedirectStatus", () => {
it.each([301, 302, 303, 307, 308])("returns true for %d", (status) => {
expect(isRedirectStatus(status)).toBe(true);
});
it.each([200, 304, 400, 404, 500])("returns false for %d", (status) => {
expect(isRedirectStatus(status)).toBe(false);
});
});
describe("isTransientWorkspaceDownloadStatus", () => {
it.each([408, 429, 500, 502, 503, 504])(
"returns true for transient %d",
(status) => {
expect(isTransientWorkspaceDownloadStatus(status)).toBe(true);
},
);
it.each([400, 401, 403, 404, 405])(
"returns false for non-transient %d",
(status) => {
expect(isTransientWorkspaceDownloadStatus(status)).toBe(false);
},
);
});
describe("getWorkspaceDownloadErrorMessage", () => {
it("extracts detail string from object", () => {
expect(getWorkspaceDownloadErrorMessage({ detail: "Not found" })).toBe(
"Not found",
);
});
it("extracts error string from object", () => {
expect(getWorkspaceDownloadErrorMessage({ error: "Server error" })).toBe(
"Server error",
);
});
it("extracts nested detail.message", () => {
expect(
getWorkspaceDownloadErrorMessage({
detail: { message: "Nested error" },
}),
).toBe("Nested error");
});
it("returns trimmed string body", () => {
expect(getWorkspaceDownloadErrorMessage(" error text ")).toBe(
"error text",
);
});
it("returns null for empty string", () => {
expect(getWorkspaceDownloadErrorMessage("")).toBeNull();
});
it("returns null for whitespace-only string", () => {
expect(getWorkspaceDownloadErrorMessage(" ")).toBeNull();
});
it("returns null for null/undefined", () => {
expect(getWorkspaceDownloadErrorMessage(null)).toBeNull();
expect(getWorkspaceDownloadErrorMessage(undefined)).toBeNull();
});
it("returns null for object with empty detail", () => {
expect(getWorkspaceDownloadErrorMessage({ detail: "" })).toBeNull();
});
it("returns null for object with no recognized keys", () => {
expect(getWorkspaceDownloadErrorMessage({ foo: "bar" })).toBeNull();
});
it("prefers detail over error", () => {
expect(
getWorkspaceDownloadErrorMessage({
detail: "detail msg",
error: "error msg",
}),
).toBe("detail msg");
});
});
describe("fetchWorkspaceDownloadOnce", () => {
beforeEach(() => {
vi.stubGlobal("fetch", vi.fn());
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it("returns response directly for non-redirect status", async () => {
const mockResponse = { ok: true, status: 200, headers: new Headers() };
vi.mocked(fetch).mockResolvedValue(mockResponse as unknown as Response);
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {});
expect(result).toBe(mockResponse);
expect(fetch).toHaveBeenCalledOnce();
});
it("follows redirect when Location header is present", async () => {
const redirectResponse = {
ok: false,
status: 302,
headers: new Headers({ Location: "https://storage.example.com/file" }),
};
const finalResponse = { ok: true, status: 200, headers: new Headers() };
vi.mocked(fetch)
.mockResolvedValueOnce(redirectResponse as unknown as Response)
.mockResolvedValueOnce(finalResponse as unknown as Response);
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {
Authorization: "Bearer token",
});
expect(result).toBe(finalResponse);
expect(fetch).toHaveBeenCalledTimes(2);
expect(fetch).toHaveBeenNthCalledWith(
2,
"https://storage.example.com/file",
{ method: "GET", redirect: "follow" },
);
});
it("returns redirect response when Location header is missing", async () => {
const redirectResponse = {
ok: false,
status: 307,
headers: new Headers(),
};
vi.mocked(fetch).mockResolvedValue(redirectResponse as unknown as Response);
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {});
expect(result).toBe(redirectResponse);
expect(fetch).toHaveBeenCalledOnce();
});
});
describe("fetchWorkspaceDownloadWithRetry", () => {
beforeEach(() => {
vi.stubGlobal("fetch", vi.fn());
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it("returns immediately on success", async () => {
const okResponse = { ok: true, status: 200, headers: new Headers() };
vi.mocked(fetch).mockResolvedValue(okResponse as unknown as Response);
const result = await fetchWorkspaceDownloadWithRetry(
"https://backend/file",
{},
2,
0,
);
expect(result).toBe(okResponse);
expect(fetch).toHaveBeenCalledOnce();
});
it("returns immediately on non-transient error without retrying", async () => {
const notFound = { ok: false, status: 404, headers: new Headers() };
vi.mocked(fetch).mockResolvedValue(notFound as unknown as Response);
const result = await fetchWorkspaceDownloadWithRetry(
"https://backend/file",
{},
2,
0,
);
expect(result.status).toBe(404);
expect(fetch).toHaveBeenCalledOnce();
});
it("retries on transient 502 and succeeds", async () => {
const bad = { ok: false, status: 502, headers: new Headers() };
const ok = { ok: true, status: 200, headers: new Headers() };
vi.mocked(fetch)
.mockResolvedValueOnce(bad as unknown as Response)
.mockResolvedValueOnce(ok as unknown as Response);
const result = await fetchWorkspaceDownloadWithRetry(
"https://backend/file",
{},
2,
0,
);
expect(result).toBe(ok);
expect(fetch).toHaveBeenCalledTimes(2);
});
it("returns last transient response after exhausting retries", async () => {
const bad = { ok: false, status: 503, headers: new Headers() };
vi.mocked(fetch).mockResolvedValue(bad as unknown as Response);
const result = await fetchWorkspaceDownloadWithRetry(
"https://backend/file",
{},
2,
0,
);
expect(result.status).toBe(503);
expect(fetch).toHaveBeenCalledTimes(3);
});
it("retries on network error and throws after exhausting retries", async () => {
vi.mocked(fetch).mockRejectedValue(new Error("Connection reset"));
await expect(
fetchWorkspaceDownloadWithRetry("https://backend/file", {}, 1, 0),
).rejects.toThrow("Connection reset");
expect(fetch).toHaveBeenCalledTimes(2);
});
});

View File

@@ -0,0 +1,108 @@
export function isWorkspaceDownloadRequest(path: string[]): boolean {
return (
path.length == 5 &&
path[0] === "api" &&
path[1] === "workspace" &&
path[2] === "files" &&
path[path.length - 1] === "download"
);
}
export function isRedirectStatus(status: number): boolean {
return [301, 302, 303, 307, 308].includes(status);
}
export function isTransientWorkspaceDownloadStatus(status: number): boolean {
return status === 408 || status === 429 || status >= 500;
}
export function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
export async function fetchWorkspaceDownloadOnce(
backendUrl: string,
headers: Record<string, string>,
): Promise<Response> {
const backendResponse = await fetch(backendUrl, {
method: "GET",
headers,
redirect: "manual",
});
if (!isRedirectStatus(backendResponse.status)) {
return backendResponse;
}
const location = backendResponse.headers.get("Location");
if (!location) return backendResponse;
return await fetch(location, {
method: "GET",
redirect: "follow",
});
}
export async function fetchWorkspaceDownloadWithRetry(
backendUrl: string,
headers: Record<string, string>,
maxRetries: number,
retryDelayMs: number,
): Promise<Response> {
for (let attempt = 0; attempt <= maxRetries; attempt++) {
try {
const response = await fetchWorkspaceDownloadOnce(backendUrl, headers);
if (
response.ok ||
!isTransientWorkspaceDownloadStatus(response.status) ||
attempt === maxRetries
) {
return response;
}
} catch (error) {
if (attempt === maxRetries) throw error;
}
await sleep(retryDelayMs);
}
throw new Error("Workspace download failed after retries");
}
export function getWorkspaceDownloadErrorMessage(body: unknown): string | null {
if (typeof body === "string") {
const trimmed = body.trim();
return trimmed || null;
}
if (!body || typeof body !== "object") return null;
if (
"detail" in body &&
typeof body.detail === "string" &&
body.detail.trim().length > 0
) {
return body.detail.trim();
}
if (
"error" in body &&
typeof body.error === "string" &&
body.error.trim().length > 0
) {
return body.error.trim();
}
if (
"detail" in body &&
body.detail &&
typeof body.detail === "object" &&
"message" in body.detail &&
typeof body.detail.message === "string" &&
body.detail.message.trim().length > 0
) {
return body.detail.message.trim();
}
return null;
}

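One subtlety worth calling out: maxRetries counts retries after the initial attempt, so a value of 2 permits up to three fetches in total — exactly what the unit tests above assert. A minimal usage sketch, with the URL and header values invented for illustration:

```
// Hypothetical call: up to 3 attempts (1 initial + 2 retries), 500 ms apart.
const response = await fetchWorkspaceDownloadWithRetry(
  "https://backend.example.com/api/workspace/files/abc-123/download",
  { Authorization: "Bearer <token>" },
  2,
  500,
);
// Non-transient statuses (e.g. 404) return immediately without retrying;
// only 408, 429, and 5xx are retried, per isTransientWorkspaceDownloadStatus.
if (!response.ok) {
  console.error(`Workspace download failed with status ${response.status}`);
}
```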
View File

@@ -11,25 +11,17 @@ import { NextRequest, NextResponse } from "next/server";
export const maxDuration = 300; // 5 minutes timeout for large uploads
export const dynamic = "force-dynamic";
import {
fetchWorkspaceDownloadWithRetry,
getWorkspaceDownloadErrorMessage,
isWorkspaceDownloadRequest,
} from "./route.helpers";
function buildBackendUrl(path: string[], queryString: string): string {
const backendPath = path.join("/");
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
}
/**
* Handle workspace file download requests with proper binary response streaming.
*/
@@ -44,17 +36,15 @@ async function handleWorkspaceDownload(
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetchWorkspaceDownloadWithRetry(
backendUrl,
headers,
2,
500,
);
if (!response.ok) {
return await createWorkspaceDownloadErrorResponse(response);
}
// Fully buffer the response before forwarding. Passing response.body as a
@@ -81,6 +71,34 @@ async function handleWorkspaceDownload(
});
}
async function createWorkspaceDownloadErrorResponse(
response: Response,
): Promise<NextResponse> {
const contentType = response.headers.get("Content-Type")?.toLowerCase() ?? "";
try {
if (contentType.includes("application/json")) {
const body = await response.json();
return NextResponse.json(body, { status: response.status });
}
const text = await response.text();
const detail =
getWorkspaceDownloadErrorMessage(text) ||
response.statusText ||
"Failed to download file";
return NextResponse.json({ detail }, { status: response.status });
} catch {
return NextResponse.json(
{
detail: response.statusText || "Failed to download file",
},
{ status: response.status },
);
}
}
async function handleJsonRequest(
req: NextRequest,
method: string,

View File

@@ -0,0 +1,94 @@
import { describe, expect, it } from "vitest";
import {
fireEvent,
render,
screen,
waitFor,
} from "@/tests/integrations/test-utils";
import {
getPostV2UpdateUserProfileMockHandler200,
getPostV2UpdateUserProfileMockHandler422,
getPostV2UpdateUserProfileResponseMock422,
} from "@/app/api/__generated__/endpoints/store/store.msw";
import { server } from "@/mocks/mock-server";
import type { ProfileDetails } from "@/app/api/__generated__/models/profileDetails";
import { ProfileInfoForm } from "../ProfileInfoForm";
function makeProfile(overrides: Partial<ProfileDetails> = {}): ProfileDetails {
return {
name: "Initial Name",
username: "initial-user",
description: "Initial description",
links: [],
avatar_url: "",
...overrides,
} as ProfileDetails;
}
describe("ProfileInfoForm", () => {
it("renders the existing profile values into editable fields", () => {
render(<ProfileInfoForm profile={makeProfile({ name: "Hello World" })} />);
const nameInput = screen.getByTestId(
"profile-info-form-display-name",
) as HTMLInputElement;
expect(nameInput.defaultValue).toBe("Hello World");
});
it("submits the new display name to POST /api/store/profile and reflects the response", async () => {
let receivedBody: Record<string, unknown> | null = null;
server.use(
getPostV2UpdateUserProfileMockHandler200(async ({ request }) => {
receivedBody = (await request.json()) as Record<string, unknown>;
return makeProfile({ name: receivedBody?.name as string });
}),
);
render(<ProfileInfoForm profile={makeProfile({ name: "Old Name" })} />);
const nameInput = screen.getByTestId("profile-info-form-display-name");
fireEvent.change(nameInput, { target: { value: "Brand New Name" } });
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
await waitFor(() => {
expect(
receivedBody,
"POST /api/store/profile must fire when the user clicks Save",
).not.toBeNull();
});
expect(receivedBody!.name).toBe("Brand New Name");
});
it("does not silently swallow the request when the API returns 422", async () => {
let calls = 0;
server.use(
getPostV2UpdateUserProfileMockHandler422(() => {
calls += 1;
return getPostV2UpdateUserProfileResponseMock422({
detail: [
{
loc: ["body", "name"],
msg: "validation error",
type: "value_error",
},
],
});
}),
);
render(<ProfileInfoForm profile={makeProfile()} />);
const nameInput = screen.getByTestId("profile-info-form-display-name");
fireEvent.change(nameInput, { target: { value: "Anything" } });
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
await waitFor(() => {
expect(
calls,
"save click must hit the backend even when validation fails",
).toBeGreaterThan(0);
});
});
});

View File

@@ -1,3 +1,5 @@
import { render, screen } from "@testing-library/react";
import type React from "react";
import { describe, expect, it } from "vitest";
import { csvRenderer } from "./CSVRenderer";
@@ -16,6 +18,16 @@ describe("csvRenderer.canRender", () => {
it("matches .csv filename case-insensitively", () => {
expect(csvRenderer.canRender("a,b", { filename: "data.CSV" })).toBe(true);
});
it("matches TSV mime type", () => {
expect(
csvRenderer.canRender("a\tb\n1\t2", {
mimeType: "text/tab-separated-values",
}),
).toBe(true);
});
it("matches .tsv filename case-insensitively", () => {
expect(csvRenderer.canRender("a\tb", { filename: "data.TSV" })).toBe(true);
});
it("rejects non-string values", () => {
expect(csvRenderer.canRender(42, { mimeType: "text/csv" })).toBe(false);
});
@@ -64,4 +76,16 @@ describe("csvRenderer.render (parse via render output smoke)", () => {
const csv = 'name\n"She said ""hi"""';
expect(() => csvRenderer.render(csv)).not.toThrow();
});
it("renders TSV columns using tabs as the delimiter", () => {
render(
csvRenderer.render("name\tage\nAlice\t30", {
filename: "data.tsv",
}) as React.ReactElement,
);
expect(screen.getByText("name")).toBeDefined();
expect(screen.getByText("age")).toBeDefined();
expect(screen.getByText("Alice")).toBeDefined();
expect(screen.getByText("30")).toBeDefined();
});
});

View File

@@ -6,7 +6,35 @@ import {
CopyContent,
} from "../types";
function normalizeMime(mime?: string): string | undefined {
return mime?.toLowerCase().split(";")[0]?.trim();
}
function getDelimiter(metadata?: OutputMetadata): "," | "\t" {
if (
normalizeMime(metadata?.mimeType) === "text/tab-separated-values" ||
metadata?.filename?.toLowerCase().endsWith(".tsv")
) {
return "\t";
}
return ",";
}
function getDelimitedMimeType(metadata?: OutputMetadata): string {
return getDelimiter(metadata) === "\t"
? "text/tab-separated-values"
: "text/csv";
}
function getDelimitedFallbackFilename(metadata?: OutputMetadata): string {
return getDelimiter(metadata) === "\t" ? "data.tsv" : "data.csv";
}
function parseDelimitedText(
text: string,
delimiter: "," | "\t",
): { headers: string[]; rows: string[][] } {
const normalized = text
.replace(/\r\n?/g, "\n")
.replace(/^\ufeff/, "")
@@ -32,7 +60,7 @@ function parseCSV(text: string): { headers: string[]; rows: string[][] } {
}
} else if (ch === '"') {
inQuotes = true;
} else if (ch === ",") {
} else if (ch === delimiter) {
row.push(current);
current = "";
} else if (ch === "\n") {
@@ -51,8 +79,17 @@ function parseCSV(text: string): { headers: string[]; rows: string[][] } {
return { headers, rows: rows.slice(1) };
}
function CSVTable({
value,
delimiter,
}: {
value: string;
delimiter: "," | "\t";
}) {
const { headers, rows } = useMemo(
() => parseDelimitedText(value, delimiter),
[delimiter, value],
);
const [sortCol, setSortCol] = useState<number | null>(null);
const [sortAsc, setSortAsc] = useState(true);
@@ -134,16 +171,17 @@ function CSVTable({ value }: { value: string }) {
function canRenderCSV(value: unknown, metadata?: OutputMetadata): boolean {
if (typeof value !== "string") return false;
const mime = normalizeMime(metadata?.mimeType);
if (mime === "text/csv" || mime === "text/tab-separated-values") {
return true;
}
if (metadata?.filename?.toLowerCase().endsWith(".csv")) return true;
if (metadata?.filename?.toLowerCase().endsWith(".tsv")) return true;
return false;
}
function renderCSV(value: unknown, metadata?: OutputMetadata): React.ReactNode {
return <CSVTable value={String(value)} delimiter={getDelimiter(metadata)} />;
}
function getCopyContentCSV(
@@ -159,10 +197,11 @@ function getDownloadContentCSV(
metadata?: OutputMetadata,
): DownloadContent | null {
const text = String(value);
const mimeType = getDelimitedMimeType(metadata);
return {
data: new Blob([text], { type: "text/csv" }),
filename: metadata?.filename || "data.csv",
mimeType: "text/csv",
data: new Blob([text], { type: mimeType }),
filename: metadata?.filename || getDelimitedFallbackFilename(metadata),
mimeType,
};
}

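To make the selection rules concrete, here are a few hypothetical inputs and the delimiter getDelimiter resolves them to, given that normalizeMime strips MIME parameters and the extension check is case-insensitive:

```
// Hypothetical inputs → delimiter chosen by getDelimiter above:
getDelimiter({ mimeType: "text/tab-separated-values; charset=utf-8" }); // "\t"
getDelimiter({ filename: "report.TSV" }); // "\t" — extension check lowercases first
getDelimiter({ mimeType: "TEXT/CSV" }); // "," — normalizeMime lowercases the MIME type
getDelimiter(undefined); // "," — comma is the default
```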
View File

@@ -0,0 +1,76 @@
import { render, screen } from "@/tests/integrations/test-utils";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { AgentActivityDropdown } from "../AgentActivityDropdown";
import { AgentExecutionWithInfo } from "../helpers";
import { beforeEach, describe, expect, test, vi } from "vitest";
const mockUseAgentActivityDropdown = vi.hoisted(() => vi.fn());
vi.mock("../useAgentActivityDropdown", () => ({
useAgentActivityDropdown: mockUseAgentActivityDropdown,
}));
function makeExecution(
overrides: Partial<AgentExecutionWithInfo> = {},
): AgentExecutionWithInfo {
return {
id: "exec-1",
graph_id: "graph-1",
status: AgentExecutionStatus.RUNNING,
started_at: new Date(),
ended_at: null,
user_id: "user-1",
graph_version: 1,
inputs: {},
credential_inputs: {},
nodes_input_masks: {},
preset_id: null,
stats: null,
agent_name: "Test Agent",
agent_description: "A running agent",
library_agent_id: "library-1",
...overrides,
};
}
describe("AgentActivityDropdown", () => {
beforeEach(() => {
mockUseAgentActivityDropdown.mockReturnValue({
activeExecutions: [makeExecution(), makeExecution({ id: "exec-2" })],
recentCompletions: [],
recentFailures: [],
totalCount: 2,
isReady: true,
error: null,
isOpen: false,
setIsOpen: vi.fn(),
});
});
test("shows the active execution badge count", () => {
render(<AgentActivityDropdown />);
expect(screen.getByTestId("agent-activity-badge").textContent).toContain(
"2",
);
expect(screen.getByTestId("agent-activity-button")).toBeDefined();
});
test("renders the dropdown content when open", async () => {
mockUseAgentActivityDropdown.mockReturnValue({
activeExecutions: [makeExecution()],
recentCompletions: [],
recentFailures: [],
totalCount: 1,
isReady: true,
error: null,
isOpen: true,
setIsOpen: vi.fn(),
});
render(<AgentActivityDropdown />);
expect(screen.getByTestId("agent-activity-dropdown")).toBeDefined();
expect(await screen.findByText("Test Agent")).toBeDefined();
});
});

View File

@@ -0,0 +1,97 @@
import { describe, expect, test } from "vitest";
import { setNestedProperty } from "./utils";
const testCases = [
{
name: "simple property assignment",
path: "name",
value: "John",
expected: { name: "John" },
},
{
name: "nested property with dot notation",
path: "user.settings.theme",
value: "dark",
expected: { user: { settings: { theme: "dark" } } },
},
{
name: "nested property with slash notation",
path: "user/settings/language",
value: "en",
expected: { user: { settings: { language: "en" } } },
},
{
name: "mixed dot and slash notation",
path: "user.settings/preferences.color",
value: "blue",
expected: { user: { settings: { preferences: { color: "blue" } } } },
},
{
name: "overwrite primitive with object",
path: "user.details",
value: { age: 30 },
expected: { user: { details: { age: 30 } } },
},
];
describe("setNestedProperty", () => {
for (const { name, path, value, expected } of testCases) {
test(name, () => {
const obj = {};
setNestedProperty(obj, path, value);
expect(obj).toEqual(expected);
});
}
test("throws for null object", () => {
expect(() => {
setNestedProperty(null, "test", "value");
}).toThrow("Target must be a non-null object");
});
test("throws for undefined object", () => {
expect(() => {
setNestedProperty(undefined, "test", "value");
}).toThrow("Target must be a non-null object");
});
test("throws for non-object target", () => {
expect(() => {
setNestedProperty("string", "test", "value");
}).toThrow("Target must be a non-null object");
});
test("throws for empty path", () => {
expect(() => {
setNestedProperty({}, "", "value");
}).toThrow("Path must be a non-empty string");
});
test("throws for __proto__ access", () => {
expect(() => {
setNestedProperty({}, "__proto__.malicious", "attack");
}).toThrow("Invalid property name: __proto__");
});
test("throws for constructor access", () => {
expect(() => {
setNestedProperty({}, "constructor.prototype.malicious", "attack");
}).toThrow("Invalid property name: constructor");
});
test("throws for prototype access", () => {
expect(() => {
setNestedProperty({}, "obj.prototype.malicious", "attack");
}).toThrow("Invalid property name: prototype");
});
test("prevents prototype pollution", () => {
const obj = {};
expect(() => {
setNestedProperty(obj, "__proto__.polluted", true);
}).toThrow("Invalid property name: __proto__");
expect(({} as { polluted?: boolean }).polluted).toBeUndefined();
});
});

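The implementation under test lives in ./utils and is not part of this diff. A minimal sketch consistent with the assertions above — the separators, guards, and error messages are all inferred from the tests, so treat it as an illustration rather than the actual source:

```
// Sketch only: behavior inferred from the tests above, not the repo's utils.ts.
const FORBIDDEN_KEYS = new Set(["__proto__", "constructor", "prototype"]);

export function setNestedProperty(
  target: unknown,
  path: string,
  value: unknown,
): void {
  if (target === null || typeof target !== "object") {
    throw new Error("Target must be a non-null object");
  }
  if (typeof path !== "string" || path.length === 0) {
    throw new Error("Path must be a non-empty string");
  }
  // Both "." and "/" act as separators, matching the mixed-notation case.
  const keys = path.split(/[./]/).filter(Boolean);
  let current = target as Record<string, unknown>;
  keys.forEach((key, i) => {
    if (FORBIDDEN_KEYS.has(key)) {
      throw new Error(`Invalid property name: ${key}`);
    }
    if (i === keys.length - 1) {
      current[key] = value;
      return;
    }
    if (current[key] === null || typeof current[key] !== "object") {
      current[key] = {}; // overwrite primitives with a fresh object
    }
    current = current[key] as Record<string, unknown>;
  });
}
```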
View File

@@ -0,0 +1,100 @@
import { randomUUID } from "crypto";
import { expect, test } from "./coverage-fixture";
import { E2E_AUTH_STATES } from "./credentials/accounts";
test.use({ storageState: E2E_AUTH_STATES.parallelB });
test("api keys happy path: user can create, copy, and revoke an API key", async ({
page,
context,
}) => {
test.setTimeout(120000);
await context.grantPermissions(["clipboard-read", "clipboard-write"]);
const keyName = `E2E CLI Key ${randomUUID().slice(0, 8)}`;
await page.goto("/profile/api-keys");
await expect(page).toHaveURL(/\/profile\/api-keys/);
await expect(
page.getByText(
"Manage your AutoGPT Platform API keys for programmatic access",
),
).toBeVisible();
await page.getByRole("button", { name: "Create Key" }).click();
await page.getByLabel("Name").fill(keyName);
const executeGraphCheckbox = page.getByRole("checkbox", {
name: /EXECUTE_GRAPH/i,
});
const executeGraphChecked =
(await executeGraphCheckbox.getAttribute("aria-checked")) === "true";
if (!executeGraphChecked) {
await executeGraphCheckbox.click();
}
await expect(executeGraphCheckbox).toHaveAttribute("aria-checked", "true");
await page.getByRole("button", { name: "Create" }).click();
const secretDialog = page.getByRole("dialog", {
name: "AutoGPT Platform API Key Created",
});
await expect
.poll(
async () => {
if (await secretDialog.isVisible().catch(() => false)) {
return "created";
}
const creationFailed = await page
.getByText("Failed to create AutoGPT Platform API key")
.isVisible()
.catch(() => false);
if (creationFailed) {
return "failed";
}
return "pending";
},
{
timeout: 30000,
message:
"API key creation should either open the created-key dialog or surface an explicit failure toast",
},
)
.toBe("created");
await expect(secretDialog).toBeVisible();
const createdSecret = (
(await secretDialog.locator("code").textContent()) ?? ""
).trim();
expect(createdSecret.length).toBeGreaterThan(0);
await secretDialog.getByRole("button").first().click();
await expect(page.getByText("Copied", { exact: true })).toBeVisible({
timeout: 15000,
});
await expect
.poll(() => page.evaluate(() => navigator.clipboard.readText()), {
timeout: 10000,
})
.toBe(createdSecret);
await secretDialog.getByRole("button", { name: "Close" }).first().click();
const createdKeyRow = page
.getByTestId("api-key-row")
.filter({ hasText: keyName })
.first();
await expect(createdKeyRow).toBeVisible({ timeout: 15000 });
await createdKeyRow.getByTestId("api-key-actions").click();
await page.getByRole("menuitem", { name: "Revoke" }).click();
await expect(
page.getByText("AutoGPT Platform API key revoked successfully"),
).toBeVisible({ timeout: 15000 });
await expect(
page.getByTestId("api-key-row").filter({ hasText: keyName }),
).toHaveCount(0);
});

View File

@@ -0,0 +1,158 @@
import { expect, test } from "./coverage-fixture";
import { getSeededTestUser } from "./credentials/accounts";
import { BuildPage } from "./pages/build.page";
import { LoginPage } from "./pages/login.page";
import {
completeOnboardingWizard,
skipOnboardingIfPresent,
} from "./utils/onboarding";
import { signupTestUser } from "./utils/signup";
test("auth happy path: user can sign up with a fresh account", async ({
page,
}) => {
test.setTimeout(60000);
await signupTestUser(page, undefined, undefined, false);
await expect(page).toHaveURL(/\/onboarding/);
await expect(page.getByText("Welcome to AutoGPT")).toBeVisible();
});
test("auth happy path: user can sign up, enter the app, and log out", async ({
page,
}) => {
test.setTimeout(90000);
await signupTestUser(page, undefined, undefined, false);
await expect(page).toHaveURL(/\/onboarding/);
await expect(page.getByText("Welcome to AutoGPT")).toBeVisible();
await skipOnboardingIfPresent(page, "/marketplace");
await expect(page).toHaveURL(/\/marketplace/);
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
await page.getByTestId("profile-popout-menu-trigger").click();
await page.getByRole("button", { name: "Log out" }).click();
await expect(page).toHaveURL(/\/login/);
await page.goto("/library");
await expect(page).toHaveURL(/\/login\?next=%2Flibrary/);
});
test("auth happy path: seeded user can log in", async ({ page }) => {
test.setTimeout(60000);
const testUser = getSeededTestUser("smokeAuth");
const loginPage = new LoginPage(page);
await page.goto("/login");
await loginPage.login(testUser.email, testUser.password);
await expect(page).toHaveURL(/\/marketplace/);
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
});
test("auth happy path: seeded user can log out and protected routes redirect to login", async ({
page,
}) => {
test.setTimeout(60000);
const testUser = getSeededTestUser("primary");
const loginPage = new LoginPage(page);
await page.goto("/login");
await loginPage.login(testUser.email, testUser.password);
await expect(page).toHaveURL(/\/marketplace/);
await page.getByTestId("profile-popout-menu-trigger").click();
await page.getByRole("button", { name: "Log out" }).click();
await expect(page).toHaveURL(/\/login/, { timeout: 15000 });
await page.goto("/profile");
await expect(page).toHaveURL(/\/login\?next=%2Fprofile/);
});
test("auth happy path: user can complete onboarding and land in the app", async ({
page,
}) => {
test.setTimeout(60000);
await signupTestUser(page, undefined, undefined, false);
await expect(page).toHaveURL(/\/onboarding/);
await completeOnboardingWizard(page, {
name: "Smoke User",
role: "Engineering",
painPoints: ["Research", "Reports & data"],
});
await expect(page).toHaveURL(/\/copilot/);
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
});
test("auth happy path: multi-tab logout clears shared builder sessions", async ({
context,
}) => {
// Two pages + builder load + logout sequence justifies a higher timeout
test.setTimeout(90000);
const consoleErrors: string[] = [];
const page1 = await context.newPage();
const page2 = await context.newPage();
const buildPage = new BuildPage(page1);
const recordWebSocketErrors =
(label: string) => (msg: { type: () => string; text: () => string }) => {
if (msg.type() === "error" && msg.text().includes("WebSocket")) {
consoleErrors.push(`${label}: ${msg.text()}`);
}
};
page1.on("console", recordWebSocketErrors("page1"));
page2.on("console", recordWebSocketErrors("page2"));
await signupTestUser(page1, undefined, undefined, false);
await expect(page1).toHaveURL(/\/onboarding/);
await skipOnboardingIfPresent(page1, "/build");
await page1.goto("/build");
await expect(page1).toHaveURL(/\/build/);
await buildPage.closeTutorial();
await expect(page1.getByTestId("profile-popout-menu-trigger")).toBeVisible();
await page2.goto("/build");
await expect(page2).toHaveURL(/\/build/);
await expect(page2.getByTestId("profile-popout-menu-trigger")).toBeVisible();
await page1.getByTestId("profile-popout-menu-trigger").click();
await page1.getByRole("button", { name: "Log out" }).click();
await expect(page1).toHaveURL(/\/login/);
await page2.reload();
await expect(page2).toHaveURL(/\/login\?next=%2Fbuild/);
await expect(page2.getByTestId("profile-popout-menu-trigger")).toBeHidden();
expect(consoleErrors).toHaveLength(0);
// Prove the auth token is actually gone, not just the UI hidden. Supabase
// overwrites the cookie on signout with an empty value + past expiry
// rather than deleting it. An assertion that is silently skipped when the
// cookie is missing under the expected name would hide a real regression,
// so we assert that every sb-*auth-token* cookie present has an empty value.
const cookiesAfterLogout = await context.cookies();
const authCookies = cookiesAfterLogout.filter(
(c) => c.name.startsWith("sb-") && c.name.includes("auth-token"),
);
for (const cookie of authCookies) {
expect(
cookie.value,
`supabase auth cookie ${cookie.name} must be empty after logout`,
).toBe("");
}
await page1.close();
await page2.close();
});

View File

@@ -0,0 +1,83 @@
import { expect, test } from "./coverage-fixture";
import { E2E_AUTH_STATES } from "./credentials/accounts";
import { BuildPage } from "./pages/build.page";
test.use({ storageState: E2E_AUTH_STATES.builder });
test("builder happy path: user can walk through the builder tutorial and cancel midway, persisting canceled state", async ({
page,
}) => {
test.setTimeout(180000);
const buildPage = new BuildPage(page);
await buildPage.startTutorial();
await buildPage.walkWelcomeToBlockMenu();
await buildPage.walkSearchAndAddCalculator();
await buildPage.cancelTutorial();
expect(await buildPage.getTutorialStateFromStorage()).toBe("canceled");
expect(await buildPage.getNodeCount()).toBeGreaterThanOrEqual(1);
});
test("builder happy path: user can skip the builder tutorial from the welcome step", async ({
page,
}) => {
test.setTimeout(60000);
const buildPage = new BuildPage(page);
await buildPage.startTutorial();
await buildPage.skipTutorialFromWelcome();
});
test("builder happy path: user can create a simple agent in builder with core blocks", async ({
page,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
await buildPage.open();
await buildPage.addSimpleAgentBlocks();
await expect(buildPage.getNodeLocator()).toHaveCount(2);
await expect(
buildPage
.getNodeLocator(0)
.locator('input[placeholder="Enter string value..."]'),
).toHaveValue("smoke-value");
await expect(buildPage.getNodeTextInput("Add to Dictionary", 0)).toHaveValue(
"smoke-key",
);
await expect(buildPage.getNodeTextInput("Add to Dictionary", 1)).toHaveValue(
"smoke-value",
);
});
test("builder happy path: user can save the created agent", async ({
page,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
await buildPage.createAndSaveSimpleAgent("Smoke Save Agent");
await expect(page).toHaveURL(/flowID=/);
expect(await buildPage.isRunButtonEnabled()).toBeTruthy();
});
test("builder happy path: user can run the saved agent from builder and see execution state", async ({
page,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
await buildPage.createAndSaveSimpleAgent("Smoke Run Agent");
await buildPage.startRun();
await expect(
page.locator('[data-id="stop-graph-button"], [data-id="run-graph-button"]'),
).toBeVisible({ timeout: 15000 });
await expect
.poll(() => buildPage.getExecutionState(), { timeout: 15000 })
.not.toBe("unknown");
});

View File

@@ -0,0 +1,44 @@
import { expect, test } from "./coverage-fixture";
import { E2E_AUTH_STATES } from "./credentials/accounts";
import { CopilotPage } from "./pages/copilot.page";
test.use({ storageState: E2E_AUTH_STATES.marketplace });
test("copilot happy path: user can create a deterministic AutoPilot session and keep it after reload", async ({
page,
}) => {
test.setTimeout(120000);
const copilotPage = new CopilotPage(page);
await copilotPage.open();
const sessionId = await copilotPage.createSessionViaApi();
await copilotPage.open(sessionId);
await copilotPage.waitForChatInput();
await page.reload();
await page.waitForLoadState("domcontentloaded");
await copilotPage.dismissNotificationPrompt();
await expect
.poll(() => new URL(page.url()).searchParams.get("sessionId"), {
timeout: 15000,
})
.toBe(sessionId);
await copilotPage.waitForChatInput();
// Sending a message must render the user's prompt in the conversation
// immediately. This catches a regression where the chat input accepts
// text but Enter is a no-op, without depending on knowing the exact
// backend endpoint name (which has shifted historically).
const userPrompt = `ping from e2e ${Date.now().toString().slice(-6)}`;
const chatInput = copilotPage.getChatInput();
await chatInput.fill(userPrompt);
await chatInput.press("Enter");
await expect(
page.getByText(userPrompt, { exact: false }).first(),
"user's typed prompt must appear in the chat after pressing Enter",
).toBeVisible({ timeout: 15000 });
});

View File

@@ -0,0 +1,85 @@
import path from "path";
export const SEEDED_TEST_PASSWORD =
process.env.SEEDED_TEST_PASSWORD || "testpassword123";
export const SEEDED_USER_POOL_VERSION = "2.0.0";
export const SEEDED_TEST_ACCOUNTS = {
primary: {
key: "primary",
email: "test123@example.com",
password: SEEDED_TEST_PASSWORD,
},
smokeAuth: {
key: "smokeAuth",
email: "e2e.qa.auth@example.com",
password: SEEDED_TEST_PASSWORD,
},
smokeBuilder: {
key: "smokeBuilder",
email: "e2e.qa.builder@example.com",
password: SEEDED_TEST_PASSWORD,
},
smokeLibrary: {
key: "smokeLibrary",
email: "e2e.qa.library@example.com",
password: SEEDED_TEST_PASSWORD,
},
smokeMarketplace: {
key: "smokeMarketplace",
email: "e2e.qa.marketplace@example.com",
password: SEEDED_TEST_PASSWORD,
},
smokeSettings: {
key: "smokeSettings",
email: "e2e.qa.settings@example.com",
password: SEEDED_TEST_PASSWORD,
},
parallelA: {
key: "parallelA",
email: "e2e.qa.parallel.a@example.com",
password: SEEDED_TEST_PASSWORD,
},
parallelB: {
key: "parallelB",
email: "e2e.qa.parallel.b@example.com",
password: SEEDED_TEST_PASSWORD,
},
} as const;
export type SeededTestAccountKey = keyof typeof SEEDED_TEST_ACCOUNTS;
export type SeededTestAccount =
(typeof SEEDED_TEST_ACCOUNTS)[SeededTestAccountKey];
export const SEEDED_TEST_USERS = Object.values(SEEDED_TEST_ACCOUNTS);
export const SEEDED_AUTH_STATE_ACCOUNT_KEYS = [
"smokeBuilder",
"smokeLibrary",
"smokeMarketplace",
"smokeSettings",
"parallelA",
"parallelB",
] as const;
export const AUTH_DIRECTORY = path.resolve(process.cwd(), ".auth");
export function getAuthStatePath(accountKey: SeededTestAccountKey) {
return path.join(AUTH_DIRECTORY, "states", `${accountKey}.json`);
}
export const E2E_AUTH_STATES = {
builder: getAuthStatePath("smokeBuilder"),
library: getAuthStatePath("smokeLibrary"),
marketplace: getAuthStatePath("smokeMarketplace"),
settings: getAuthStatePath("smokeSettings"),
parallelA: getAuthStatePath("parallelA"),
parallelB: getAuthStatePath("parallelB"),
} as const;
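// Back-compat alias, presumably kept so older specs that use the smoke
// naming keep working.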
export const SMOKE_AUTH_STATES = E2E_AUTH_STATES;
export function getSeededTestUser(
accountKey: SeededTestAccountKey = "primary",
): SeededTestAccount {
return SEEDED_TEST_ACCOUNTS[accountKey];
}
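// Illustrative sketch only: the real ensureSeededAuthStates lives in
// utils/auth, which is not shown in this diff. It presumably signs each
// pooled account in once and snapshots the session to the paths above.
// The selectors and the post-login URL check below are assumptions.
import { chromium } from "@playwright/test";
import {
  getAuthStatePath,
  SEEDED_AUTH_STATE_ACCOUNT_KEYS,
  SEEDED_TEST_ACCOUNTS,
} from "./accounts";
async function saveSeededAuthStates(baseURL: string) {
  const browser = await chromium.launch();
  for (const key of SEEDED_AUTH_STATE_ACCOUNT_KEYS) {
    const { email, password } = SEEDED_TEST_ACCOUNTS[key];
    const page = await browser.newPage({ baseURL });
    await page.goto("/login");
    await page.getByLabel(/email/i).fill(email);
    await page.getByLabel(/password/i).fill(password);
    await page.getByRole("button", { name: /log ?in/i }).click();
    await page.waitForURL((url) => !url.pathname.startsWith("/login"));
    await page.context().storageState({ path: getAuthStatePath(key) });
    await page.close();
  }
  await browser.close();
}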

View File

@@ -0,0 +1,27 @@
import { getSeededTestUser } from "./accounts";
// E2E Test Credentials and Constants
export const TEST_CREDENTIALS = getSeededTestUser("primary");
export function getTestUserWithLibraryAgents() {
return TEST_CREDENTIALS;
}
// Sentinel value developers can use to flag agents that require no real input
export const DummyInput = "DummyInput";
// Marketplace submission payload used when testing agent submission as test123@example.com
export const TEST_AGENT_DATA = {
name: "E2E Calculator Agent",
description:
"A deterministic marketplace agent built from Calculator and Agent Output blocks for frontend E2E coverage.",
image_urls: [
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
],
video_url: "https://www.youtube.com/watch?v=test123",
sub_heading: "A deterministic calculator agent for PR E2E coverage",
categories: ["test", "demo", "frontend"],
changes_summary: "Initial deterministic calculator submission",
} as const;

View File

@@ -0,0 +1,23 @@
export function buildCookieConsentStorageState(
origin: string = "http://localhost:3000",
) {
return {
cookies: [],
origins: [
{
origin,
localStorage: [
{
name: "autogpt_cookie_consent",
value: JSON.stringify({
hasConsented: true,
timestamp: Date.now(),
analytics: true,
monitoring: true,
}),
},
],
},
],
};
}
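// Illustrative usage (not part of this diff; import path assumed): the
// returned object matches Playwright's storageState shape, so a spec can
// pass it straight to test.use() and never see the consent banner.
import { test } from "@playwright/test";
import { buildCookieConsentStorageState } from "./cookie-consent";
test.use({ storageState: buildCookieConsentStorageState() });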

View File

@@ -0,0 +1,49 @@
import { FullConfig } from "@playwright/test";
import {
ensureSeededAuthStates,
getInvalidSeededAuthStateKeys,
} from "./utils/auth";
function resolveBaseURL(config: FullConfig) {
const configuredBaseURL =
config.projects[0]?.use?.baseURL ?? "http://localhost:3000";
if (typeof configuredBaseURL !== "string") {
throw new Error(
`Playwright baseURL must be a string during global setup. Received ${String(
configuredBaseURL,
)}.`,
);
}
return configuredBaseURL;
}
async function globalSetup(config: FullConfig) {
console.log("🚀 Starting global test setup...");
try {
const baseURL = resolveBaseURL(config);
const invalidKeys = await getInvalidSeededAuthStateKeys(baseURL);
if (invalidKeys.length === 0) {
console.log("♻️ Reusing stored seeded auth states");
return;
}
console.log(
`🔐 Refreshing seeded auth states for: ${invalidKeys.join(", ")}`,
);
await ensureSeededAuthStates(baseURL);
console.log("✅ Global setup completed successfully!");
} catch (error) {
console.error("❌ Global setup failed:", error);
console.error(
"💡 Run backend/test/e2e_test_data.py to seed the deterministic Playwright accounts before retrying.",
);
throw error;
}
}
export default globalSetup;
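// Wiring sketch (paths assumed; the Playwright config itself is not in this
// excerpt): the suite registers this file once via the globalSetup option.
import { defineConfig } from "@playwright/test";
export default defineConfig({
  globalSetup: "./src/tests/global-setup.ts",
  use: { baseURL: "http://localhost:3000" },
});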

View File

@@ -0,0 +1,559 @@
import path from "path";
import type { Page } from "@playwright/test";
import { expect, test } from "./coverage-fixture";
import { E2E_AUTH_STATES } from "./credentials/accounts";
import { BuildPage, createUniqueAgentName } from "./pages/build.page";
import {
clickRunButton,
dismissFeedbackDialog,
getActiveItemId,
importAgentFromFile,
LibraryPage,
} from "./pages/library.page";
test.use({ storageState: E2E_AUTH_STATES.library });
const TEST_AGENT_PATH = path.resolve(__dirname, "assets", "testing_agent.json");
const CALCULATOR_BLOCK_ID = "b1ab9b19-67a6-406d-abf5-2dba76d00c79";
const AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4";
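// Terminal execution statuses: once the executions API reports any of these,
// the run is no longer active. "completed" is included in case a short chain
// finishes before the stop request lands.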
const STOPPED_RUN_STATUSES = new Set([
"terminated",
"failed",
"incomplete",
"completed",
]);
type UploadedGraphNode = {
id: string;
block_id: string;
input_default: Record<string, unknown>;
metadata: {
position: {
x: number;
y: number;
};
};
input_links: unknown[];
output_links: unknown[];
};
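// Builds a chain of `calculatorCount` sequential Calculator blocks feeding an
// Agent Output block, so the run stays in "running" long enough for the
// stop-task flow to observe and interrupt it.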
function createLongRunningCalculatorGraph(
agentName: string,
calculatorCount: number = 150,
) {
const nodes: UploadedGraphNode[] = Array.from(
{ length: calculatorCount },
(_, index) => ({
id: `calc-${index + 1}`,
block_id: CALCULATOR_BLOCK_ID,
input_default:
index === 0
? {
operation: "Add",
a: 1,
b: 1,
round_result: false,
}
: {
operation: "Add",
b: 1,
round_result: false,
},
metadata: {
position: { x: 320 * index, y: 120 },
},
input_links: [],
output_links: [],
}),
);
const links = Array.from({ length: calculatorCount - 1 }, (_, index) => ({
source_id: `calc-${index + 1}`,
sink_id: `calc-${index + 2}`,
source_name: "result",
sink_name: "a",
}));
nodes.push({
id: "final-output",
block_id: AGENT_OUTPUT_BLOCK_ID,
input_default: {
name: "Final result",
description: "Long-running calculator chain output",
},
metadata: {
position: { x: 320 * calculatorCount, y: 120 },
},
input_links: [],
output_links: [],
});
links.push({
source_id: `calc-${calculatorCount}`,
sink_id: "final-output",
source_name: "result",
sink_name: "value",
});
return {
name: agentName,
description:
"Deterministic long-running calculator chain for runner stop coverage",
is_active: true,
nodes,
links,
};
}
async function createLongRunningSavedAgent(
page: Page,
agentName: string,
): Promise<{ graphId: string; graphVersion: number }> {
const response = await page.request.post("/api/proxy/api/graphs", {
data: {
graph: createLongRunningCalculatorGraph(agentName),
source: "upload",
},
});
expect(response.ok(), "expected graph creation API request to succeed").toBe(
true,
);
const body = (await response.json()) as {
id?: string;
version?: number;
data?: { id?: string; version?: number };
};
expect(
body.data?.id ?? body.id,
"graph creation should return a graph id",
).toBeTruthy();
return {
graphId: String(body.data?.id ?? body.id),
graphVersion: Number(body.data?.version ?? body.version ?? 1),
};
}
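// Creates a minimal 1 + 1 calculator graph via the upload API so the run
// always produces the same output value ("2"), keeping the exact-output
// assertion stable.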
async function createDeterministicCalculatorSavedAgent(
page: Page,
agentName: string,
outputName: string,
): Promise<void> {
const response = await page.request.post("/api/proxy/api/graphs", {
data: {
graph: {
name: agentName,
description:
"Deterministic calculator output for run-result assertions",
is_active: true,
nodes: [
{
id: "calc-1",
block_id: CALCULATOR_BLOCK_ID,
input_default: {
operation: "Add",
a: 1,
b: 1,
round_result: false,
},
metadata: {
position: { x: 120, y: 160 },
},
input_links: [],
output_links: [],
},
{
id: "final-output",
block_id: AGENT_OUTPUT_BLOCK_ID,
input_default: {
name: outputName,
description: "Deterministic result output",
},
metadata: {
position: { x: 520, y: 160 },
},
input_links: [],
output_links: [],
},
],
links: [
{
source_id: "calc-1",
sink_id: "final-output",
source_name: "result",
sink_name: "value",
},
],
},
source: "upload",
},
});
expect(
response.ok(),
"expected deterministic calculator graph creation API request to succeed",
).toBe(true);
}
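// Reads the run status straight from the executions API, so stop/finish
// assertions don't depend on the UI having re-rendered yet.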
async function getExecutionStatusFromApi(
page: Page,
graphId: string,
runId: string,
): Promise<string> {
const response = await page.request.get(
`/api/proxy/api/graphs/${graphId}/executions/${runId}`,
);
expect(response.ok(), "execution details API should succeed").toBe(true);
const body = (await response.json()) as { status?: string };
return body.status?.toLowerCase() ?? "unknown";
}
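// Unlike the API-built graphs above, this drives the actual builder UI
// (Store Value → Agent Output) to produce a saved agent with a known output
// value for library run verification.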
async function createAndSaveDeterministicOutputAgent(
page: Page,
prefix: string,
): Promise<{ agentName: string; expectedOutput: string; outputName: string }> {
const buildPage = new BuildPage(page);
const agentName = createUniqueAgentName(prefix);
const expectedOutput = `e2e-output-${Date.now()}`;
const outputName = `e2e-result-${Date.now()}`;
await buildPage.open();
await buildPage.addBlockByClick("Store Value");
await buildPage.waitForNodeOnCanvas(1);
await buildPage.fillBlockInputByPlaceholder(
"Enter string value...",
expectedOutput,
0,
);
await buildPage.addBlockByClick("Agent Output");
await buildPage.waitForNodeOnCanvas(2);
await buildPage.connectNodes(0, 1);
await buildPage.fillLastNodeTextInput("Agent Output", outputName);
await buildPage.saveAgent(
agentName,
"Deterministic output agent for library run verification",
);
await buildPage.waitForSaveComplete();
await buildPage.waitForSaveButton();
return { agentName, expectedOutput, outputName };
}
test("library happy path: user can import an agent file into Library", async ({
page,
}) => {
test.setTimeout(120000);
const { importedAgent } = await importAgentFromFile(
page,
TEST_AGENT_PATH,
createUniqueAgentName("E2E Import Agent"),
);
expect(importedAgent.name).toContain("E2E Import Agent");
});
test("library happy path: user can open the imported or saved agent from Library in builder", async ({
page,
}) => {
test.setTimeout(120000);
const { libraryPage, importedAgent } = await importAgentFromFile(
page,
TEST_AGENT_PATH,
createUniqueAgentName("E2E Open Agent"),
);
// Register the popup listener before clicking so we don't miss a fast open.
// A short timeout covers the case where the link opens in the current tab.
const popupPromise = page
.context()
.waitForEvent("page", { timeout: 10000 })
.catch(() => null);
await libraryPage.clickOpenInBuilder(importedAgent);
const builderPage = (await popupPromise) ?? page;
await builderPage.waitForLoadState("domcontentloaded");
await expect(builderPage).toHaveURL(/\/build/);
const importedBuildPage = new BuildPage(builderPage);
await importedBuildPage.waitForNodeOnCanvas();
expect(await importedBuildPage.getNodeCount()).toBeGreaterThan(0);
if (builderPage !== page) {
await builderPage.close();
}
});
test("library happy path: user can start and stop a saved task from runner UI", async ({
page,
}) => {
test.setTimeout(180000);
const agentName = createUniqueAgentName("E2E Stop Task Agent");
const { graphId } = await createLongRunningSavedAgent(page, agentName);
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await expect
.poll(() => getActiveItemId(page), { timeout: 45000 })
.not.toBe(null);
const runId = getActiveItemId(page);
expect(runId, "run id should be present after starting task").toBeTruthy();
await expect
.poll(() => libraryPage.getRunStatus(), { timeout: 45000 })
.toBe("running");
const stopTaskButton = page.getByRole("button", { name: /Stop task/i });
await expect(stopTaskButton).toBeVisible({ timeout: 30000 });
const stopResponsePromise = page.waitForResponse(
(response) =>
response.request().method() === "POST" &&
response
.url()
.includes(`/api/graphs/${graphId}/executions/${runId}/stop`),
{ timeout: 15000 },
);
await stopTaskButton.click();
const stopResponse = await stopResponsePromise;
expect(stopResponse.ok(), "stop run API should succeed").toBe(true);
await expect(page.getByText("Run stopped")).toBeVisible({ timeout: 15000 });
await expect
.poll(
async () => {
const status = await getExecutionStatusFromApi(
page,
graphId,
String(runId),
);
return STOPPED_RUN_STATUSES.has(status) ? status : "running";
},
{ timeout: 45000 },
)
.not.toBe("running");
});
test("library happy path: user can run a saved agent and verify expected output", async ({
page,
}) => {
test.setTimeout(150000);
const agentName = createUniqueAgentName("E2E Expected Output Agent");
const outputName = `e2e-result-${Date.now()}`;
await createDeterministicCalculatorSavedAgent(page, agentName, outputName);
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await libraryPage.waitForRunToComplete();
await dismissFeedbackDialog(page);
await libraryPage.assertRunProducedOutput();
await libraryPage.assertRunOutputValue(outputName, /^2(?:\.0+)?$/);
await expect
.poll(() => libraryPage.getRunStatus(), { timeout: 15000 })
.toBe("completed");
});
test("library happy path: user can edit a saved agent from Library and keep changes after refresh", async ({
page,
}) => {
test.setTimeout(150000);
const { agentName } = await createAndSaveDeterministicOutputAgent(
page,
"E2E Edit Persist Agent",
);
const editedValue = `edited-value-${Date.now()}`;
const libraryPage = new LibraryPage(page);
await page.goto("/library");
await libraryPage.waitForAgentsToLoad();
await libraryPage.searchAgents(agentName);
await libraryPage.waitForAgentsToLoad();
const agentCard = page
.getByTestId("library-agent-card")
.filter({ hasText: agentName })
.first();
await expect(agentCard).toBeVisible({ timeout: 15000 });
const popupPromise = page
.context()
.waitForEvent("page", { timeout: 10000 })
.catch(() => null);
await agentCard
.getByTestId("library-agent-card-open-in-builder-link")
.first()
.click();
const builderPage = (await popupPromise) ?? page;
const builderTabPage = new BuildPage(builderPage);
await builderTabPage.waitForNodeOnCanvas();
await builderTabPage.fillBlockInputByPlaceholder(
"Enter string value...",
editedValue,
0,
);
await builderPage.getByTestId("save-control-save-button").click();
const saveAgentButton = builderPage.getByRole("button", {
name: "Save Agent",
});
if (await saveAgentButton.isVisible({ timeout: 3000 }).catch(() => false)) {
await expect(saveAgentButton).toBeEnabled({ timeout: 10000 });
await saveAgentButton.click();
await expect(saveAgentButton).toBeHidden({ timeout: 15000 });
}
await builderPage.reload();
await builderTabPage.waitForNodeOnCanvas();
await expect(
builderTabPage
.getNodeLocator(0)
.locator('input[placeholder="Enter string value..."]'),
).toHaveValue(editedValue);
if (builderPage !== page) {
await builderPage.close();
}
});
test("library happy path: user can rerun a completed task from the Library agent page", async ({
page,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
const { agentName } =
await buildPage.createAndSaveSimpleAgent("E2E Rerun Agent");
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await libraryPage.waitForRunToComplete();
await dismissFeedbackDialog(page);
const rerunTaskButton = page.getByRole("button", { name: /Rerun task/i });
await expect(rerunTaskButton).toBeVisible({ timeout: 45000 });
await expect
.poll(() => getActiveItemId(page), { timeout: 45000 })
.not.toBe(null);
const initialRunId = getActiveItemId(page);
expect(initialRunId).toBeTruthy();
await rerunTaskButton.click();
await expect(page.getByText("Run started", { exact: true })).toBeVisible({
timeout: 15000,
});
await expect
.poll(() => getActiveItemId(page), { timeout: 45000 })
.not.toBe(initialRunId);
await libraryPage.waitForRunToComplete();
// Simple agent has no AgentOutputBlock — verify run completion only.
const runStatus = await libraryPage.getRunStatus();
expect(runStatus).toBe("completed");
});
test("library happy path: user can delete a completed task from the run sidebar", async ({
page,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
const { agentName } = await buildPage.createAndSaveSimpleAgent(
"E2E Delete Task Agent",
);
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await libraryPage.waitForRunToComplete();
await dismissFeedbackDialog(page);
// Open the per-task actions dropdown ("More actions" three-dot button)
// and use the menu's Delete task option to remove the run.
const moreActionsButton = page
.getByRole("button", { name: "More actions" })
.first();
await expect(moreActionsButton).toBeVisible({ timeout: 15000 });
await moreActionsButton.click();
await page.getByRole("menuitem", { name: /Delete( this)? task/i }).click();
const confirmDialog = page.getByRole("dialog", { name: /Delete task/i });
await expect(confirmDialog).toBeVisible({ timeout: 10000 });
await confirmDialog.getByRole("button", { name: /^Delete Task$/ }).click();
// Toast confirms the backend actually deleted (not just dialog closed).
await expect(page.getByText("Task deleted", { exact: true })).toBeVisible({
timeout: 15000,
});
  // Sidebar should drop the only run, returning the page to its initial
  // task-entry state.
await expect(
page.getByRole("button", { name: /^(Setup your task|New task)$/i }),
).toBeVisible({ timeout: 15000 });
});
test("library happy path: user can open the agent in builder from the exact runner customise-agent path", async ({
page,
context,
}) => {
test.setTimeout(120000);
const buildPage = new BuildPage(page);
const { agentName } = await buildPage.createAndSaveSimpleAgent(
"E2E View Task Agent",
);
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await libraryPage.waitForRunToComplete();
await dismissFeedbackDialog(page);
// The "View task details" eye-icon button on a completed run opens the
// agent in the builder in a new tab. This exercises the runner → builder
// navigation that QA item #22 ("Customise Agent" from Runner UI) covers.
const selectedRunId = getActiveItemId(page);
expect(selectedRunId).toBeTruthy();
const viewTaskButton = page
.locator('[aria-label="View task details"]')
.first();
await expect(viewTaskButton).toBeVisible({ timeout: 15000 });
const customiseAgentHref = await viewTaskButton.getAttribute("href");
expect(customiseAgentHref).toContain("flowID=");
expect(customiseAgentHref).toContain("flowVersion=");
expect(customiseAgentHref).toContain(`flowExecutionID=${selectedRunId}`);
const popupPromise = context.waitForEvent("page", { timeout: 15000 });
await viewTaskButton.click();
const builderTab = await popupPromise;
await builderTab.waitForLoadState("domcontentloaded");
await expect(builderTab).toHaveURL(/\/build/);
await expect(builderTab).toHaveURL(
new RegExp(`flowExecutionID=${selectedRunId}`),
);
// Verify the builder canvas actually rendered with the agent's nodes —
// a navigation that lands on /build but never paints the graph would
// otherwise pass on URL alone.
const builderTabPage = new BuildPage(builderTab);
await builderTabPage.waitForNodeOnCanvas();
expect(await builderTabPage.getNodeCount()).toBeGreaterThan(0);
await builderTab.close();
});

View File

@@ -0,0 +1,48 @@
import { expect, test } from "./coverage-fixture";
import { E2E_AUTH_STATES } from "./credentials/accounts";
import {
clickRunButton,
dismissFeedbackDialog,
LibraryPage,
} from "./pages/library.page";
import { MarketplacePage } from "./pages/marketplace.page";
test.use({ storageState: E2E_AUTH_STATES.marketplace });
test("marketplace happy path: user can browse Marketplace and open an agent detail page", async ({
page,
}) => {
test.setTimeout(90000);
const marketplacePage = new MarketplacePage(page);
await marketplacePage.openFeaturedAgent();
await expect(page.getByTestId("agent-description")).toBeVisible();
});
test("marketplace happy path: user can add a Marketplace agent to Library and run it", async ({
page,
}) => {
test.setTimeout(120000);
const marketplacePage = new MarketplacePage(page);
await marketplacePage.openRunnableAgent();
const agentName = await page.getByTestId("agent-title").innerText();
await page.getByTestId("agent-add-library-button").click();
await expect(page.getByText("Redirecting to your library...")).toBeVisible();
await expect(page).toHaveURL(/\/library\/agents\//);
const libraryPage = new LibraryPage(page);
await libraryPage.openSavedAgent(agentName);
await clickRunButton(page);
await libraryPage.waitForRunToComplete();
await dismissFeedbackDialog(page);
const runStatus = await libraryPage.getRunStatus();
expect(runStatus).toBe("completed");
await libraryPage.assertRunProducedOutput();
await libraryPage.assertFirstRunOutputValue(/^\d+(?:\.0+)?$/);
});

Some files were not shown because too many files have changed in this diff.